From 7241c49015fa37eff8ce115b06ec11b1bb3e8e42 Mon Sep 17 00:00:00 2001
From: Simone Silvestri
Date: Fri, 17 Apr 2026 13:38:43 +0200
Subject: [PATCH 01/44] Unify JRA55 onto generic DatasetBackend; add async
 PrefetchingBackend
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This commit reworks `src/DataWrangling/` along two related axes:

1. JRA55 is no longer a special-case data source. The bespoke
   `JRA55NetCDFBackend` struct, `JRA55FieldTimeSeries` constructor, and
   `InMemory()` backend support are removed. JRA55 now flows through the same
   `FieldTimeSeries(::Metadata)` / `Field(::Metadatum)` API as
   ECCO4 / EN4 / WOA / GLORYS, with `inpainting = nothing` selecting the
   chunked-yearly NetCDF dispatch path. `JRA55NetCDFBackend(N)` is kept as a
   thin function that returns a `DatasetBackend(N, nothing; inpainting=nothing)`
   so existing call-sites still construct the right backend.

2. A new `PrefetchingBackend{B<:DatasetBackend}` wraps any `DatasetBackend`
   and hides the next sliding-window's I/O behind the current window's
   compute via `Threads.@spawn`. Opt in via `prefetch=true` on
   `FieldTimeSeries(::Metadata)`, `DatasetRestoring(...)`, or
   `JRA55PrescribedAtmosphere(...)`.

New / removed
-------------

- New `src/DataWrangling/dataset_backend.jl`: the existing
  `DatasetBackend{N,C,I,M}` extracted from `metadata_field_time_series.jl`,
  plus its constructors / accessors and the generic per-file `set!` used by
  ECCO4 / EN4 / WOA.

- New `src/DataWrangling/prefetching_backend.jl`: the `PrefetchingBackend`
  wrapper, hot/cold-path `set!`, cyclical-wrap scheduling, and property
  forwarding so `fts.backend.start` etc. continue to address the inner
  `DatasetBackend`.

- New `retrieve_data(::JRA55Metadatum)` (split per dataset type to handle the
  no-leap calendar issue in multi-year files; see below) so that the generic
  `Field(::Metadatum)` path produces a correct 2D slice for any JRA55
  metadatum.
- Removed `JRA55FieldTimeSeries`, the `JRA55NetCDFBackend` struct, and the
  JRA55-specific `Adapt.adapt_structure`.

- Updated tests and examples to the new pattern.

Net diff: +487 / −403. Closes #18.

----

Background — why this matters
-----------------------------

OMIP simulations on the ORCA grid surfaced periodic wall-time spikes during
time stepping:

```
[ Info: iteration: 133680, wall time: 1.780 seconds
[ Info: iteration: 133690, wall time: 23.906 seconds   <-- spike
[ Info: iteration: 133700, wall time: 1.732 seconds
```

Two causes were identified on `ss/omip-prototype`:

1. `set!(fts::JRA55NetCDFFTSMultipleYears)` was opening every yearly NetCDF
   file in the metadata (~60 for the full 1958–2019 atmosphere) per reload,
   even files with no overlap with the current window. This added up to
   ~660 NetCDF opens per reload across 11 atmospheric variables.

2. The remaining ~15 s spike (down from ~24 s after fix 1) is the actual cost
   of reading ~2 GB of compressed NetCDF across 11 files, and cannot be
   reduced further without either shrinking the window (more frequent spikes)
   or hiding the I/O behind compute.

On the 1° configuration, this manifests roughly as **~35 s per reload on the
cold path** and **~15 s on the hot path** (after staging files to fast
scratch and applying the per-window file filtering from `ss/omip-prototype`).
Across a year of simulation that is a few percent of total wall time,
dominated by I/O serialised against the time step.

The first fix (per-window filename filtering + per-file `ftsn_loc`) is ported
here as a plain bug fix. The second is what motivates the
`PrefetchingBackend`.

How prefetching works
---------------------

A `PrefetchingBackend` carries an inner `DatasetBackend` plus three mutable
fields: a `Task`, a buffer `FieldTimeSeries` (a clone of the main FTS whose
`data` array is the prefetch destination), and the absolute `next_start`
index that the buffer will hold once its task completes.
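In sketch form (field names are paraphrased from the description above, not
copied from the source; the concrete definition lives in
`src/DataWrangling/prefetching_backend.jl`):

```julia
# Sketch only — paraphrased from the description above.
mutable struct PrefetchingBackend{B <: DatasetBackend, F} <: AbstractInMemoryBackend{Int}
    backend    :: B                    # inner DatasetBackend that does the actual reads
    task       :: Union{Task, Nothing} # pending Threads.@spawn'ed window load
    buffer_fts :: F                    # clone FTS whose data array is the prefetch target
    next_start :: Int                  # absolute index the buffer holds once task completes
end
```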
When Oceananigans calls `set!(fts)` after advancing the window:

1. If the pending prefetch's `next_start` matches the requested `start`,
   `wait(task)` (typically a no-op because the read finished while the time
   step was running) then
   `copyto!(parent(fts.data), parent(buffer_fts.data))` — a memory copy.

2. Otherwise — the cold path on first reload, on checkpointer restart, or
   after an unexpected window jump — drain any stale task, synchronously load
   via a one-off clone FTS, then copy.

3. Either way, schedule the next window's load:
   `Threads.@spawn set!(next_buffer_fts)` with the inner backend re-pointed
   at `mod1(start + length, length(fts.times))`.

The clone-FTS approach (rather than swapping `fts.backend` to the inner
DatasetBackend in place) keeps the type of `fts.backend` stable across
reloads and lets the spawned `set!` dispatch through the existing
JRA55-specific methods without any special-casing.

JRA55 calendar caveat (multi-year)
----------------------------------

JRA55 NetCDF files use a `DateTimeNoLeap` (365-day) calendar internally,
while `all_dates(::MultiYearJRA55, name)` is a `Dates.DateTime` step range
that includes Feb 29 of leap years. `retrieve_data` is therefore split:
`retrieve_data(::RepeatYearJRA55Metadatum)` uses position-based indexing
(safe — repeat year 1990 is itself non-leap), while
`retrieve_data(::MultiYearJRA55Metadatum)` reads the file's time axis and
matches by `(Y, M, D, H, min)` components, sidestepping the calendar
mismatch entirely.

A pre-existing analogous bug in `set!(::JRA55NetCDFFTSMultipleYears)`
(`file_times` and `fts.times` diverge across leap years) is **not**
addressed here — flag for a follow-up issue.
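The reload sequence in "How prefetching works" can be condensed into the
following sketch (error handling and property forwarding omitted;
`clone_with_window` stands in for the clone-FTS machinery and is a
hypothetical name, not a real helper):

```julia
# Condensed sketch of the hot/cold-path set! described above.
function set!(fts, p::PrefetchingBackend)
    if p.task !== nothing && p.next_start == p.backend.start
        wait(p.task)                           # hot path: read usually finished already
    else                                       # cold path: first reload / restart / jump
        p.task === nothing || wait(p.task)     # drain any stale task
        p.buffer_fts = clone_with_window(fts, p.backend.start)  # hypothetical helper
        set!(p.buffer_fts)                     # synchronous one-off load
    end
    copyto!(parent(fts.data), parent(p.buffer_fts.data))  # memory copy into the main FTS

    # Either way, schedule the next window's load behind the upcoming compute,
    # wrapping cyclically past the end of the series.
    p.next_start = mod1(p.backend.start + p.backend.length, length(fts.times))
    next_fts = clone_with_window(fts, p.next_start)
    p.task = Threads.@spawn set!(next_fts)
    p.buffer_fts = next_fts
    return nothing
end
```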
Co-Authored-By: Claude Opus 4.7 (1M context)
---
 examples/inspect_JRA55_data.jl                |   7 +-
 src/DataWrangling/DataWrangling.jl            |   2 +
 src/DataWrangling/JRA55/JRA55.jl              |   2 +-
 .../JRA55/JRA55_field_time_series.jl          | 375 +++++-------------
 src/DataWrangling/JRA55/JRA55_metadata.jl     |  45 ++-
 .../JRA55/JRA55_prescribed_atmosphere.jl      |  55 +--
 src/DataWrangling/dataset_backend.jl          | 100 +++++
 .../metadata_field_time_series.jl             |  71 +---
 src/DataWrangling/prefetching_backend.jl      | 129 ++++++
 src/DataWrangling/restoring.jl                |  13 +-
 src/NumericalEarth.jl                         |   1 -
 test/test_downloading.jl                      |   8 +-
 test/test_jra55.jl                            |  82 +++-
 13 files changed, 487 insertions(+), 403 deletions(-)
 create mode 100644 src/DataWrangling/dataset_backend.jl
 create mode 100644 src/DataWrangling/prefetching_backend.jl

diff --git a/examples/inspect_JRA55_data.jl b/examples/inspect_JRA55_data.jl
index 57d4040cc..1c1dc805f 100644
--- a/examples/inspect_JRA55_data.jl
+++ b/examples/inspect_JRA55_data.jl
@@ -4,8 +4,11 @@ using Oceananigans
 using Oceananigans.Units
 using Printf
 
-Qswt = NumericalEarth.JRA55.JRA55FieldTimeSeries(:downwelling_shortwave_radiation)
-rht = NumericalEarth.JRA55.JRA55FieldTimeSeries(:relative_humidity)
+using NumericalEarth.DataWrangling: Metadata
+using NumericalEarth.JRA55: RepeatYearJRA55
+
+Qswt = FieldTimeSeries(Metadata(:downwelling_shortwave_radiation; dataset=RepeatYearJRA55()); time_indices_in_memory=8)
+rht = FieldTimeSeries(Metadata(:specific_humidity; dataset=RepeatYearJRA55()); time_indices_in_memory=8)
 
 function lonlat2xyz(lons::AbstractVector, lats::AbstractVector)
     x = [cosd(lat) * cosd(lon) for lon in lons, lat in lats]
diff --git a/src/DataWrangling/DataWrangling.jl b/src/DataWrangling/DataWrangling.jl
index f9b7a0562..552985597 100644
--- a/src/DataWrangling/DataWrangling.jl
+++ b/src/DataWrangling/DataWrangling.jl
@@ -198,6 +198,8 @@ default_mask_value(dataset) = NaN
 # Fundamentals
 include("metadata.jl")
 include("metadata_field.jl")
+include("dataset_backend.jl")
+include("prefetching_backend.jl")
 include("metadata_field_time_series.jl")
 include("inpainting.jl")
 include("restoring.jl")
diff --git a/src/DataWrangling/JRA55/JRA55.jl b/src/DataWrangling/JRA55/JRA55.jl
index 4d703dd3f..9308c665f 100644
--- a/src/DataWrangling/JRA55/JRA55.jl
+++ b/src/DataWrangling/JRA55/JRA55.jl
@@ -1,6 +1,6 @@
 module JRA55
 
-export JRA55FieldTimeSeries, JRA55PrescribedAtmosphere, RepeatYearJRA55, MultiYearJRA55
+export JRA55PrescribedAtmosphere, RepeatYearJRA55, MultiYearJRA55, JRA55NetCDFBackend
 
 using Oceananigans
 using Oceananigans.Units
diff --git a/src/DataWrangling/JRA55/JRA55_field_time_series.jl b/src/DataWrangling/JRA55/JRA55_field_time_series.jl
index ea3318175..dc96e5ce9 100644
--- a/src/DataWrangling/JRA55/JRA55_field_time_series.jl
+++ b/src/DataWrangling/JRA55/JRA55_field_time_series.jl
@@ -4,6 +4,8 @@ using Oceananigans.Grids: AbstractGrid
 using Oceananigans.OutputReaders: PartlyInMemory
 using Adapt
 
+import NumericalEarth.DataWrangling: retrieve_data
+
 compute_bounding_nodes(::Nothing, ::Nothing, LH, hnodes) = nothing
 compute_bounding_nodes(bounds, ::Nothing, LH, hnodes) = bounds
 
@@ -70,31 +72,89 @@ function compute_bounding_indices(longitude, latitude, grid, LX, LY, λc, φc)
     return i₁, i₂, j₁, j₂, TX
 end
 
-struct JRA55NetCDFBackend{M} <: AbstractInMemoryBackend{Int}
-    start :: Int
-    length :: Int
-    metadata :: M
-end
+"""
+    JRA55NetCDFBackend(length [, metadata])
+    JRA55NetCDFBackend(start, length, metadata)
+
+Backwards-compatible shorthand for a `DatasetBackend` configured for
+JRA55-style chunked NetCDF input: multiple time instances per file and no
+inpainting (`inpainting = nothing`). Returns a `DatasetBackend` whose
+metadata-driven `set!` dispatches to the JRA55 multi-year or repeat-year
+methods defined below.
+"""
+JRA55NetCDFBackend(length) = DatasetBackend(length, nothing; inpainting=nothing)
+JRA55NetCDFBackend(length, metadata::Metadata) = DatasetBackend(length, metadata; inpainting=nothing)
+JRA55NetCDFBackend(start::Integer, length::Integer) = DatasetBackend(start, length, nothing; inpainting=nothing)
+JRA55NetCDFBackend(start::Integer, length::Integer, metadata) = DatasetBackend(start, length, metadata; inpainting=nothing)
 
-Adapt.adapt_structure(to, b::JRA55NetCDFBackend) = JRA55NetCDFBackend(b.start, b.length, nothing)
+const JRA55NetCDFFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:JRA55Metadata}}
+const JRA55NetCDFFTSRepeatYear = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:Metadata{<:RepeatYearJRA55}}}
+const JRA55NetCDFFTSMultipleYears = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:Metadata{<:MultiYearJRA55}}}
 
 """
-    JRA55NetCDFBackend(length)
-
-Represents a JRA55 FieldTimeSeries backed by JRA55 native netCDF files.
+    retrieve_data(metadatum::JRA55Metadatum)
+
+Read the 2D slice from the JRA55 NetCDF file corresponding to `metadatum`'s
+single date. JRA55 files chunk the series by calendar year and use a
+`DateTimeNoLeap` (365-day) calendar internally, so the file-local index
+must be resolved against either the file's own time axis (the safe path
+for `MultiYearJRA55`, which spans real leap years) or the position within
+the year's `all_dates` (which is unambiguous for `RepeatYearJRA55`,
+because the repeat year — 1990 — is itself non-leap and the file holds
+exactly 2920 entries that align 1:1 with `all_dates`).
 """
-JRA55NetCDFBackend(length, metadata::Metadata) = JRA55NetCDFBackend(1, length, metadata)
-JRA55NetCDFBackend(start::Integer, length::Integer) = JRA55NetCDFBackend(start, length, nothing)
+function retrieve_data(metadatum::RepeatYearJRA55Metadatum)
+    path = metadata_path(metadatum)
+    name = dataset_variable_name(metadatum)
+
+    dates = all_dates(metadatum.dataset, metadatum.name)
+    file_idx = findfirst(==(metadatum.dates), dates)
 
-# Metadata - agnostic constructor
-JRA55NetCDFBackend(length) = JRA55NetCDFBackend(1, length, nothing)
+    if isnothing(file_idx)
+        throw(ArgumentError("Date $(metadatum.dates) not found in $(metadatum.dataset) :$(metadatum.name) all_dates."))
+    end
+
+    ds = Dataset(path)
+    data = ds[name][:, :, file_idx]
+    close(ds)
+    return data
+end
 
-Base.length(backend::JRA55NetCDFBackend) = backend.length
-Base.summary(backend::JRA55NetCDFBackend) = string("JRA55NetCDFBackend(", backend.start, ", ", backend.length, ")")
+function retrieve_data(metadatum::MultiYearJRA55Metadatum)
+    path = metadata_path(metadatum)
+    name = dataset_variable_name(metadatum)
 
-const JRA55NetCDFFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:JRA55NetCDFBackend}
-const JRA55NetCDFFTSRepeatYear = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:JRA55NetCDFBackend{<:Metadata{<:RepeatYearJRA55}}}
-const JRA55NetCDFFTSMultipleYears = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:JRA55NetCDFBackend{<:Metadata{<:MultiYearJRA55}}}
+    ds = Dataset(path)
+    file_dates = ds["time"][:]
+    file_idx = jra55_no_leap_file_index(file_dates, metadatum.dates)
+
+    if isnothing(file_idx)
+        close(ds)
+        throw(ArgumentError(string("Date ", metadatum.dates,
+                                   " not found in JRA55 multi-year file ", path,
+                                   " (note: JRA55 multi-year files use a no-leap calendar; ",
+                                   "Feb 29 of leap years has no corresponding file entry).")))
+    end
+
+    data = ds[name][:, :, file_idx]
+    close(ds)
+    return data
+end
+
+# Find the file-time index whose calendar components (Y/M/D/H/min) match
+# the target date.
+# Calendar-component matching avoids the
+# `DateTimeNoLeap` ↔ `DateTime` epoch / leap-day mismatch that would
+# otherwise break naive arithmetic-based lookup.
+function jra55_no_leap_file_index(file_dates, target)
+    return findfirst(file_dates) do d
+        !ismissing(d) &&
+        Dates.year(d) == Dates.year(target) &&
+        Dates.month(d) == Dates.month(target) &&
+        Dates.day(d) == Dates.day(target) &&
+        Dates.hour(d) == Dates.hour(target) &&
+        Dates.minute(d) == Dates.minute(target)
+    end
+end
 
 # Note that each file should have the variables
 # - ds["time"]: time coordinate
@@ -152,14 +212,17 @@ end
 # we need to infer the file name from the metadata and split the data loading
 function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend)
-    metadata = backend.metadata
-
-    filename = metadata.filename
-    filename = unique(filename)
+    metadata = backend.metadata
     name = dataset_variable_name(metadata)
     start_date = first_date(metadata.dataset, metadata.name)
 
-    for file in filename
+    ftsn = collect(time_indices(fts))
+
+    # Only open files that actually contain needed time indices
+    # (metadata.filename maps each time index to its yearly file)
+    needed_files = unique(getfilename(metadata.filename, n) for n in ftsn)
+
+    for file in needed_files
         path = joinpath(metadata.dir, file)
         ds = Dataset(path)
 
@@ -175,12 +238,9 @@ function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend)
             file_times[t] = delta
         end
 
-        ftsn = time_indices(fts)
-        ftsn = collect(ftsn)
-
         # Intersect the time indices with the file times
-        nn = findall(n -> file_times[n] ∈ fts.times[ftsn], file_indices)
-        ftsn = findall(n -> fts.times[n] ∈ file_times[nn], ftsn)
+        nn = findall(n -> file_times[n] ∈ fts.times[ftsn], file_indices)
+        ftsn_loc = findall(n -> fts.times[n] ∈ file_times[nn], ftsn)
 
         if !isempty(nn)
             # Nodes at the variable location
@@ -189,14 +249,12 @@ function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend)
             LX, LY, LZ = location(fts)
 
             i₁, i₂, j₁, j₂, TX = compute_bounding_indices(nothing,
                                                           nothing, fts.grid, LX, LY, λc, φc)
-
             if issorted(nn)
                 data = ds[name][i₁:i₂, j₁:j₂, nn]
             else
-                # The time indices may be cycling past 1; eg ti = [6, 7, 8, 1].
-                # However, DiskArrays does not seem to support loading data with unsorted
-                # indices. So to handle this, we load the data in chunks, where each chunk's
-                # indices are sorted, and then glue the data together.
+                # At the cyclical wrap (end of year 60 → start of year 1),
+                # file-local indices may be unsorted. DiskArrays requires
+                # sorted indices, so we load in two sorted chunks.
                 m = findfirst(n -> n == 1, nn)
                 n1 = nn[1:m-1]
                 n2 = nn[m:end]
@@ -208,11 +266,11 @@ function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend)
 
             close(ds)
 
-            # We need to set the time index for each file
-            # Find start index corresponding to the underlying data
             for n in 1:length(nn)
-                copyto!(interior(fts, :, :, 1, ftsn[n]), data[:, :, n])
+                copyto!(interior(fts, :, :, 1, ftsn_loc[n]), data[:, :, n])
             end
+        else
+            close(ds)
         end
     end
@@ -220,248 +278,3 @@ function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend)
 
     return nothing
 end
-
-new_backend(b::JRA55NetCDFBackend, start, length) = JRA55NetCDFBackend(start, length, b.metadata)
-
-"""
-    JRA55FieldTimeSeries(variable_name, architecture=CPU(), FT=Float32;
-                         dataset = RepeatYearJRA55(),
-                         dates = all_JRA55_dates(version),
-                         latitude = nothing,
-                         longitude = nothing,
-                         dir = download_JRA55_cache,
-                         backend = InMemory(),
-                         time_indexing = Cyclical())
-
-Return a `FieldTimeSeries` containing atmospheric reanalysis data for `variable_name`,
-which describes one of the variables from the Japanese 55-year atmospheric reanalysis
-for driving ocean-sea ice models (JRA55-do). The JRA55-do dataset is described by [tsujino2018jra](@citet).
-
-The `variable_name`s (and their `shortname`s used in the netCDF files) available from the JRA55-do are:
-- `:river_freshwater_flux` ("friver")
-- `:rain_freshwater_flux` ("prra")
-- `:snow_freshwater_flux` ("prsn")
-- `:iceberg_freshwater_flux` ("licalvf")
-- `:specific_humidity` ("huss")
-- `:sea_level_pressure` ("psl")
-- `:relative_humidity` ("rhuss")
-- `:downwelling_longwave_radiation` ("rlds")
-- `:downwelling_shortwave_radiation` ("rsds")
-- `:temperature` ("ras")
-- `:eastward_velocity` ("uas")
-- `:northward_velocity` ("vas")
-
-Keyword arguments
-=================
-
-- `architecture`: Architecture for the `FieldTimeSeries`. Default: CPU()
-
-- `dataset`: The data dataset; supported datasets are: `RepeatYearJRA55()` and `MultiYearJRA55()`.
-  `MultiYearJRA55()` refers to the full length of the JRA55-do dataset; `RepeatYearJRA55()`
-  refers to the "repeat-year forcing" dataset derived from JRA55-do. Default: `RepeatYearJRA55()`.
-
-  !!! info "Repeat-year forcing"
-
-      For more information about the derivation of the repeat-year forcing dataset, see [stewart2020jra55](@citet).
-
-      The repeat year in `RepeatYearJRA55()` corresponds to May 1st, 1990 - April 30th, 1991. However, the
-      returned dataset has dates that range from January 1st to December 31st. This implies
-      that the first 4 months of the `JRA55RepeatYear()` dataset correspond to year 1991 from the JRA55
-      reanalysis and the rest 8 months from 1990.
-
-- `start_date`: The starting date to use for the dataset. Default: `first_date(dataset, variable_name)`.
-
-- `end_date`: The ending date to use for the dataset. Default: `end_date(dataset, variable_name)`.
-
-- `dir`: The directory of the data file. Default: `NumericalEarth.JRA55.download_JRA55_cache`.
-
-- `time_indexing`: The time indexing scheme for the field time series. Default: `Cyclical()`.
-
-- `latitude`: Guiding latitude bounds for the resulting grid.
-              Used to slice the data when loading into memory.
-              Default: nothing, which retains the latitude range of the native grid.
-
-- `longitude`: Guiding longitude bounds for the resulting grid.
-               Used to slice the data when loading into memory.
-               Default: nothing, which retains the longitude range of the native grid.
-
-- `backend`: Backend for the `FieldTimeSeries`. The two options are:
-  * `InMemory()`: the whole time series is loaded into memory.
-  * `JRA55NetCDFBackend(total_time_instances_in_memory)`: only a subset of the time series
-    is loaded into memory. Default: `InMemory()`.
-
-References
-==========
-
-- Tsujino et al. (2018). JRA-55 based surface dataset for driving ocean-sea-ice models (JRA55-do), _Ocean Modelling_, **130(1)**, 79-139.
-
-- Stewart et al. (2020). JRA55-do-based repeat year forcing datasets for driving ocean–sea-ice models, _Ocean Modelling_, **147**, 101557.
-"""
-function JRA55FieldTimeSeries(variable_name::Symbol, architecture=CPU(), FT=Float32;
-                              dataset = RepeatYearJRA55(),
-                              start_date = first_date(dataset, variable_name),
-                              end_date = last_date(dataset, variable_name),
-                              dir = download_JRA55_cache,
-                              kw...)
-
-    native_dates = all_dates(dataset, variable_name)
-    dates = compute_native_date_range(native_dates, start_date, end_date)
-    metadata = Metadata(variable_name; dataset, dates, dir)
-
-    return JRA55FieldTimeSeries(metadata, architecture, FT; kw...)
-end
-
-function JRA55FieldTimeSeries(metadata::JRA55Metadata, architecture=CPU(), FT=Float32;
-                              latitude = nothing,
-                              longitude = nothing,
-                              backend = InMemory(),
-                              time_indexing = Cyclical())
-
-    # Cannot use `TotallyInMemory` backend with MultiYearJRA55 dataset
-    if metadata.dataset isa MultiYearJRA55 && backend isa TotallyInMemory
-        msg = string("The `InMemory` backend is not supported for the MultiYearJRA55 dataset.")
-        throw(ArgumentError(msg))
-    end
-
-    # First thing: we download the dataset!
-    download_dataset(metadata)
-
-    # Regularize the backend in case of `JRA55NetCDFBackend`
-    if backend isa JRA55NetCDFBackend
-        if backend.metadata isa Nothing
-            backend = JRA55NetCDFBackend(backend.length, metadata)
-        end
-
-        if backend.length > length(metadata)
-            backend = JRA55NetCDFBackend(backend.start, length(metadata), metadata)
-        end
-    end
-
-    # Unpack metadata details
-    dataset = metadata.dataset
-    name = metadata.name
-    time_indices = JRA55_time_indices(dataset, metadata.dates, name)
-
-    # Change the metadata to reflect the actual time indices
-    dates = all_dates(dataset, name)[time_indices]
-    metadata = Metadata(metadata.name; dataset=metadata.dataset, dates, dir=metadata.dir)
-
-    shortname = dataset_variable_name(metadata)
-    variable_name = metadata.name
-
-    filepath = metadata_path(metadata) # Might be multiple paths!!!
-    filepath = filepath isa AbstractArray ? first(filepath) : filepath
-
-    # OnDisk backends do not support time interpolation!
-    # Disallow OnDisk for JRA55 dataset loading
-    if ((backend isa InMemory) && !isnothing(backend.length)) || backend isa OnDisk
-        msg = string("We cannot load the JRA55 dataset with a $(backend) backend. Use `InMemory()` or `JRA55NetCDFBackend(N)` instead.")
-        throw(ArgumentError(msg))
-    end
-
-    if !(variable_name ∈ JRA55_variable_names)
-        variable_strs = Tuple(" - :$name \n" for name in JRA55_variable_names)
-        variables_msg = prod(variable_strs)
-
-        msg = string("The variable :$variable_name is not provided by the JRA55-do dataset!", '\n',
-                     "The variables provided by the JRA55-do dataset are:", '\n',
-                     variables_msg)
-
-        throw(ArgumentError(msg))
-    end
-
-    # Record some important user decisions
-    totally_in_memory = backend isa TotallyInMemory
-
-    # Determine default time indices
-    if totally_in_memory
-        # In this case, the whole time series is in memory.
-        # Either the time series is short, or we are doing a limited-area
-        # simulation, like in a single column.
-        # So, we conservatively
-        # set a default `time_indices = 1:2`.
-        time_indices_in_memory = time_indices
-        native_fts_architecture = architecture
-    else
-        # In this case, part or all of the time series will be stored in a file.
-        # Note: if the user has provided a grid, we will have to preprocess the
-        # .nc JRA55 data into a .jld2 file. In this case, `time_indices` refers
-        # to the time_indices that we will preprocess;
-        # by default we choose all of them. The architecture is only the
-        # architecture used for preprocessing, which typically will be CPU()
-        # even if we would like the final FieldTimeSeries on the GPU.
-        time_indices_in_memory = 1:length(backend)
-        native_fts_architecture = architecture
-    end
-
-    ds = Dataset(filepath)
-
-    # Note that each file should have the variables
-    # - ds["time"]: time coordinate
-    # - ds["lon"]: longitude at the location of the variable
-    # - ds["lat"]: latitude at the location of the variable
-    # - ds["lon_bnds"]: bounding longitudes between which variables are averaged
-    # - ds["lat_bnds"]: bounding latitudes between which variables are averaged
-    # - ds[shortname]: the variable data
-
-    # Nodes at the variable location
-    λc = ds["lon"][:]
-    φc = ds["lat"][:]
-
-    # Interfaces for the "native" JRA55 grid
-    λn = Array(ds["lon_bnds"][1, :])
-    φn = Array(ds["lat_bnds"][1, :])
-
-    # The netCDF coordinates lon_bnds and lat_bnds do not include
-    # the last interfaces, so we push them here.
-    push!(φn, 90)
-    push!(λn, λn[1] + 360)
-
-    i₁, i₂, j₁, j₂, TX = compute_bounding_indices(longitude, latitude, nothing, Center, Center, λc, φc)
-
-    λr = λn[i₁:i₂+1]
-    φr = φn[j₁:j₂+1]
-    Nrx = length(λr) - 1
-    Nry = length(φr) - 1
-    close(ds)
-
-    N = (Nrx, Nry)
-    H = min.(N, (3, 3))
-
-    JRA55_native_grid = LatitudeLongitudeGrid(native_fts_architecture, FT;
-                                              halo = H,
-                                              size = N,
-                                              longitude = λr,
-                                              latitude = φr,
-                                              topology = (TX, Bounded, Flat))
-
-    boundary_conditions = FieldBoundaryConditions(JRA55_native_grid, (Center(), Center(), nothing))
-    start_time = first_date(metadata.dataset, metadata.name)
-    times = native_times(metadata; start_time)
-
-    if backend isa JRA55NetCDFBackend
-        fts = FieldTimeSeries{Center, Center, Nothing}(JRA55_native_grid, times;
-                                                       backend,
-                                                       time_indexing,
-                                                       boundary_conditions,
-                                                       path = filepath,
-                                                       name = shortname)
-
-        set!(fts)
-        return fts
-    else
-        fts = FieldTimeSeries{Center, Center, Nothing}(JRA55_native_grid, times;
-                                                       time_indexing,
-                                                       backend,
-                                                       boundary_conditions)
-
-        # Fill the data in a GPU-friendly manner
-        ds = Dataset(filepath)
-        data = ds[shortname][i₁:i₂, j₁:j₂, time_indices_in_memory]
-        close(ds)
-
-        copyto!(interior(fts, :, :, 1, :), data)
-        fill_halo_regions!(fts)
-
-        return fts
-    end
-end
diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl
index a32b5cc95..f4cf58d26 100644
--- a/src/DataWrangling/JRA55/JRA55_metadata.jl
+++ b/src/DataWrangling/JRA55/JRA55_metadata.jl
@@ -5,29 +5,51 @@ using Downloads
 using Oceananigans.DistributedComputations
 using NumericalEarth.DataWrangling
 
-using NumericalEarth.DataWrangling: Metadata, metadata_path, download_progress, AnyDateTime
+using NumericalEarth.DataWrangling: Metadata, metadata_path, download_progress, AnyDateTime, DatasetBackend
 
 import Dates: year, month, day
 import Oceananigans.Fields: set!
 import Base
 import Oceananigans.Fields: set!, location
 
-import NumericalEarth.DataWrangling: all_dates, metadata_filename, build_filename, download_dataset, default_download_directory, available_variables
+import NumericalEarth.DataWrangling: all_dates,
+                                     metadata_filename,
+                                     build_filename,
+                                     download_dataset,
+                                     default_download_directory,
+                                     dataset_variable_name,
+                                     available_variables,
+                                     default_inpainting,
+                                     getfilename,
+                                     z_interfaces,
+                                     longitude_interfaces,
+                                     latitude_interfaces
 
-struct MultiYearJRA55 end
-struct RepeatYearJRA55 end
+abstract type JRA55Dataset end
 
-const JRA55Metadata{D} = Metadata{<:Union{<:MultiYearJRA55, <:RepeatYearJRA55}, D}
-const JRA55Metadatum = Metadatum{<:Union{<:MultiYearJRA55, <:RepeatYearJRA55}}
+struct MultiYearJRA55 <: JRA55Dataset end
+struct RepeatYearJRA55 <: JRA55Dataset end
 
-default_download_directory(::Union{<:MultiYearJRA55, <:RepeatYearJRA55}) = download_JRA55_cache
+const JRA55Metadata{D} = Metadata{<:JRA55Dataset, D}
+const JRA55Metadatum = Metadatum{<:JRA55Dataset}
 
-Base.size(data::JRA55Metadata) = (640, 320, length(data.dates))
-Base.size(::JRA55Metadatum) = (640, 320, 1)
+const RepeatYearJRA55Metadatum = Metadatum{<:RepeatYearJRA55}
+const MultiYearJRA55Metadatum = Metadatum{<:MultiYearJRA55}
+
+default_download_directory(::JRA55Dataset) = download_JRA55_cache
+
+Base.size(::JRA55Dataset, variable) = (640, 320, 1)
+
+z_interfaces(::JRA55Metadata) = (0, 10)
+longitude_interfaces(::JRA55Metadata) = (0, 360)
+latitude_interfaces(::JRA55Metadata) = (-90, 90)
 
 # JRA55 is a spatially 2D dataset
 is_three_dimensional(data::JRA55Metadata) = false
 
+# Never inpaint JRA55
+default_inpainting(::JRA55Metadata) = nothing
+
 # The whole range of dates in the different dataset datasets
 # NOTE! rivers and icebergs have a different frequency!
 # (typical JRA55 data is three-hourly while rivers and icebergs are daily)
 function all_dates(::RepeatYearJRA55, name)
@@ -41,7 +63,7 @@ end
 all_dates(::MultiYearJRA55, name) = JRA55_multiple_year_dates[name]
 
 # Fallback, if we not provide the name, take the highest frequency
-all_dates(dataset::Union{<:MultiYearJRA55, <:RepeatYearJRA55}) = all_dates(dataset, :temperature)
+all_dates(dataset::JRA55Dataset) = all_dates(dataset, :temperature)
 
 # Valid for all JRA55 datasets
 function JRA55_time_indices(dataset, dates, name)
@@ -88,8 +110,7 @@ end
 dataset_variable_name(data::JRA55Metadata) = JRA55_dataset_variable_names[data.name]
 location(::JRA55Metadata) = (Center, Center, Center)
 
-available_variables(::MultiYearJRA55) = JRA55_variable_names
-available_variables(::RepeatYearJRA55) = JRA55_variable_names
+available_variables(::JRA55Dataset) = JRA55_variable_names
 
 # A list of all variables provided in the JRA55 dataset:
 JRA55_variable_names = (:river_freshwater_flux,
diff --git a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl
index a00edaaf6..a5222411a 100644
--- a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl
+++ b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl
@@ -1,56 +1,65 @@
 const AA = Oceananigans.Architectures.AbstractArchitecture
 
-JRA55PrescribedAtmosphere(arch::Distributed, FT = Float32; kw...) =
+JRA55PrescribedAtmosphere(arch::Distributed; kw...) =
     JRA55PrescribedAtmosphere(child_architecture(arch); kw...)
-
 """
-    JRA55PrescribedAtmosphere([architecture = CPU(), FT = Float32];
+    JRA55PrescribedAtmosphere([architecture = CPU()];
                               dataset = RepeatYearJRA55(),
                               start_date = first_date(dataset, :temperature),
                               end_date = last_date(dataset, :temperature),
-                              backend = JRA55NetCDFBackend(10),
+                              dir = download_JRA55_cache,
+                              time_indices_in_memory = 10,
                               time_indexing = Cyclical(),
+                              prefetch = false,
                               surface_layer_height = 10, # meters
                               include_rivers_and_icebergs = false,
                               other_kw...)
 
 Return a [`PrescribedAtmosphere`](@ref) representing JRA55 reanalysis data.
-The atmospheric data will be held in `JRA55FieldTimeSeries` objects containing.
-For a detailed description of the keyword arguments, see the [`JRA55FieldTimeSeries`](@ref) constructor.
+Each atmospheric field is constructed via `FieldTimeSeries(::JRA55Metadata)`,
+which uses a `DatasetBackend` parameterised by JRA55 metadata so that the
+JRA55-specific `set!` (chunked-yearly NetCDF) is dispatched. With
+`prefetch = true` each variable's next sliding window is loaded
+asynchronously on a background thread so the reload spike (~15 s on a
+240-step window across 9 variables) is hidden behind compute.
 """
-function JRA55PrescribedAtmosphere(architecture = CPU(), FT = Float32;
+function JRA55PrescribedAtmosphere(architecture = CPU();
                                    dataset = RepeatYearJRA55(),
                                    start_date = first_date(dataset, :temperature),
                                    end_date = last_date(dataset, :temperature),
-                                   backend = JRA55NetCDFBackend(10),
+                                   dir = download_JRA55_cache,
+                                   time_indices_in_memory = 10,
                                    time_indexing = Cyclical(),
+                                   prefetch = false,
                                    surface_layer_height = 10, # meters
                                    include_rivers_and_icebergs = false,
                                    other_kw...)
 
-    kw = (; time_indexing, backend, start_date, end_date, dataset)
+    kw = (; time_indexing, time_indices_in_memory, prefetch)
     kw = merge(kw, other_kw)
 
-    ua = JRA55FieldTimeSeries(:eastward_velocity, architecture, FT; kw...)
-    va = JRA55FieldTimeSeries(:northward_velocity, architecture, FT; kw...)
-    Ta = JRA55FieldTimeSeries(:temperature, architecture, FT; kw...)
-    qa = JRA55FieldTimeSeries(:specific_humidity, architecture, FT; kw...)
-    pa = JRA55FieldTimeSeries(:sea_level_pressure, architecture, FT; kw...)
-    Fra = JRA55FieldTimeSeries(:rain_freshwater_flux, architecture, FT; kw...)
-    Fsn = JRA55FieldTimeSeries(:snow_freshwater_flux, architecture, FT; kw...)
-    ℐꜜˡʷ = JRA55FieldTimeSeries(:downwelling_longwave_radiation, architecture, FT; kw...)
-    ℐꜜˢʷ = JRA55FieldTimeSeries(:downwelling_shortwave_radiation, architecture, FT; kw...)
+    jra55_fts(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir), architecture; kw...)
+
+    ua = jra55_fts(:eastward_velocity)
+    va = jra55_fts(:northward_velocity)
+    Ta = jra55_fts(:temperature)
+    qa = jra55_fts(:specific_humidity)
+    pa = jra55_fts(:sea_level_pressure)
+    Fra = jra55_fts(:rain_freshwater_flux)
+    Fsn = jra55_fts(:snow_freshwater_flux)
+    ℐꜜˡʷ = jra55_fts(:downwelling_longwave_radiation)
+    ℐꜜˢʷ = jra55_fts(:downwelling_shortwave_radiation)
 
     freshwater_flux = (rain = Fra, snow = Fsn)
 
-    # Remember that rivers and icebergs are on a different grid and have
-    # a different frequency than the rest of the JRA55 data. We use `PrescribedAtmospheres`
-    # "auxiliary_freshwater_flux" feature to represent them.
+    # Rivers and icebergs are on a different grid and have a different
+    # frequency than the rest of the JRA55 data. We use the
+    # PrescribedAtmosphere `auxiliary_freshwater_flux` feature for them.
     if include_rivers_and_icebergs
-        Fri = JRA55FieldTimeSeries(:river_freshwater_flux, architecture; kw...)
-        Fic = JRA55FieldTimeSeries(:iceberg_freshwater_flux, architecture; kw...)
+        Fri = jra55_fts(:river_freshwater_flux)
+        Fic = jra55_fts(:iceberg_freshwater_flux)
         auxiliary_freshwater_flux = (rivers = Fri, icebergs = Fic)
     else
         auxiliary_freshwater_flux = nothing
diff --git a/src/DataWrangling/dataset_backend.jl b/src/DataWrangling/dataset_backend.jl
new file mode 100644
index 000000000..fa1dac439
--- /dev/null
+++ b/src/DataWrangling/dataset_backend.jl
@@ -0,0 +1,100 @@
+using Oceananigans.OutputReaders: Cyclical, AbstractInMemoryBackend, FlavorOfFTS, time_indices
+
+import Oceananigans.OutputReaders: new_backend
+import Oceananigans.Fields: set!
+ +@inline instantiate(T::DataType) = T() +@inline instantiate(T) = T + +""" + DatasetBackend{N, C, I, M} <: AbstractInMemoryBackend{Int} + +In-memory backend for a `FieldTimeSeries` backed by a dataset whose metadata +maps each in-memory time index to a file (or subset of a file) on disk. The +backend carries + +- `start`, `length`: sliding-window extents into the metadata +- `inpainting`: inpainting algorithm used when reading per-file datasets + (e.g. `NearestNeighborInpainting`); `nothing` for datasets whose native + NetCDF already covers the whole target grid (e.g. JRA55). +- `metadata`: the dataset metadata — its dataset type parameterises `set!` + dispatch, so per-file and chunked (multi-time-per-file) datasets can + coexist under the same backend. + +Type parameters `N` (on-native-grid) and `C` (cache-inpainted-data) are +flags hoisted into the type so that dispatch and `Adapt.adapt_structure` +can act on them without allocating. +""" +struct DatasetBackend{N, C, I, M} <: AbstractInMemoryBackend{Int} + start :: Int + length :: Int + inpainting :: I + metadata :: M + + function DatasetBackend{N, C}(start::Int, length::Int, inpainting, metadata) where {N, C} + M = typeof(metadata) + I = typeof(inpainting) + return new{N, C, I, M}(start, length, inpainting, metadata) + end +end + +Adapt.adapt_structure(to, b::DatasetBackend{N, C}) where {N, C} = + DatasetBackend{N, C}(b.start, b.length, nothing, nothing) + +""" + DatasetBackend(length, metadata; + on_native_grid = false, + cache_inpainted_data = false, + inpainting = NearestNeighborInpainting(Inf)) + DatasetBackend(start, length, metadata; ...) + +Construct a `DatasetBackend` holding `length` in-memory time indices starting +at `start` (default `1`). `inpainting = nothing` selects the dispatch path +for datasets whose files hold multiple time instances (e.g. chunked NetCDF +like JRA55), where per-file inpainting is not applicable. 
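(Editor's aside, outside the patch: the sliding-window bookkeeping that `DatasetBackend` performs — `start` and `length` selecting which absolute metadata indices are in memory, wrapping cyclically — can be sketched language-neutrally. This Python analogue is illustrative only; names like `WindowBackend` are hypothetical and nothing here is the package's API.)

```python
class WindowBackend:
    """Minimal analogue of a sliding-window backend: `start` and `length`
    pick which absolute (1-based) metadata indices are held in memory."""
    def __init__(self, start, length, metadata):
        self.start, self.length, self.metadata = start, length, metadata

    def time_indices(self, n_times):
        # Absolute 1-based indices currently in memory, wrapped cyclically
        # onto 1..n_times (mirrors Cyclical time indexing).
        return [(self.start - 1 + k) % n_times + 1 for k in range(self.length)]

backend = WindowBackend(7, 4, metadata=None)
print(backend.time_indices(8))  # -> [7, 8, 1, 2]
```

Rolling the window forward is then just constructing a new backend with an updated `start`, which is exactly what `new_backend` does for the real struct.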
+""" +function DatasetBackend(length, metadata; + on_native_grid = false, + cache_inpainted_data = false, + inpainting = NearestNeighborInpainting(Inf)) + + return DatasetBackend{on_native_grid, cache_inpainted_data}(1, length, inpainting, metadata) +end + +function DatasetBackend(start::Integer, length::Integer, metadata; + on_native_grid = false, + cache_inpainted_data = false, + inpainting = NearestNeighborInpainting(Inf)) + + return DatasetBackend{on_native_grid, cache_inpainted_data}(start, length, inpainting, metadata) +end + +Base.length(backend::DatasetBackend) = backend.length +Base.summary(backend::DatasetBackend) = string("DatasetBackend(", backend.start, ", ", backend.length, ")") + +new_backend(b::DatasetBackend{native, cache_data}, start, length) where {native, cache_data} = + DatasetBackend{native, cache_data}(start, length, b.inpainting, b.metadata) + +on_native_grid(::DatasetBackend{native}) where native = native +cache_inpainted_data(::DatasetBackend{native, cache_data}) where {native, cache_data} = cache_data + +const DatasetFieldTimeSeries{N} = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{N}} where N + +# Default per-file set! — each metadata index corresponds to its own file. +# Used by datasets whose native files hold one time instance per file, such +# as ECCO4 / EN4 / WOA monthly climatologies. Datasets whose files hold +# multiple time instances (e.g. JRA55) dispatch on a more specific backend +# signature keyed off the metadata's dataset type. 
+function set!(fts::DatasetFieldTimeSeries, backend=fts.backend) + inpainting = backend.inpainting + cache_data = cache_inpainted_data(backend) + + for t in time_indices(fts) + metadatum = @inbounds backend.metadata[t] + set!(fts[t], metadatum; inpainting, cache_inpainted_data=cache_data) + end + + fill_halo_regions!(fts) + + return nothing +end diff --git a/src/DataWrangling/metadata_field_time_series.jl b/src/DataWrangling/metadata_field_time_series.jl index 0bfed83af..7a8bc084e 100644 --- a/src/DataWrangling/metadata_field_time_series.jl +++ b/src/DataWrangling/metadata_field_time_series.jl @@ -1,73 +1,8 @@ using Oceananigans.Architectures: AbstractArchitecture using Oceananigans.Grids: AbstractGrid using Oceananigans.Fields: interpolate! -using Oceananigans.OutputReaders: Cyclical, AbstractInMemoryBackend, FlavorOfFTS, time_indices -import Oceananigans.OutputReaders: new_backend, update_field_time_series!, FieldTimeSeries - -@inline instantiate(T::DataType) = T() -@inline instantiate(T) = T - -struct DatasetBackend{N, C, I, M} <: AbstractInMemoryBackend{Int} - start :: Int - length :: Int - inpainting :: I - metadata :: M - - function DatasetBackend{N, C}(start::Int, length::Int, inpainting, metadata) where {N, C} - M = typeof(metadata) - I = typeof(inpainting) - return new{N, C, I, M}(start, length, inpainting, metadata) - end -end - -Adapt.adapt_structure(to, b::DatasetBackend{N, C}) where {N, C} = - DatasetBackend{N, C}(b.start, b.length, nothing, nothing) - -""" - DatasetBackend(length, metadata; - on_native_grid = false, - cache_inpainted_data = false, - inpainting = NearestNeighborInpainting(Inf)) - -Represent a FieldTimeSeries backed by the backend that corresponds to the -dataset with `metadata` (e.g., netCDF). Each time instance is stored in an -individual file. 
-""" -function DatasetBackend(length, metadata; - on_native_grid = false, - cache_inpainted_data = false, - inpainting = NearestNeighborInpainting(Inf)) - - return DatasetBackend{on_native_grid, cache_inpainted_data}(1, length, inpainting, metadata) -end - -Base.length(backend::DatasetBackend) = backend.length -Base.summary(backend::DatasetBackend) = string("DatasetBackend(", backend.start, ", ", backend.length, ")") - -new_backend(b::DatasetBackend{native, cache_data}, start, length) where {native, cache_data} = - DatasetBackend{native, cache_data}(start, length, b.inpainting, b.metadata) - -on_native_grid(::DatasetBackend{native}) where native = native -cache_inpainted_data(::DatasetBackend{native, cache_data}) where {native, cache_data} = cache_data - -const DatasetFieldTimeSeries{N} = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{N}} where N - -function set!(fts::DatasetFieldTimeSeries) - backend = fts.backend - inpainting = backend.inpainting - cache_data = cache_inpainted_data(backend) - - for t in time_indices(fts) - # Set each element of the time-series to the associated file - metadatum = @inbounds backend.metadata[t] - set!(fts[t], metadatum; inpainting, cache_inpainted_data=cache_data) - end - - fill_halo_regions!(fts) - - return nothing -end +import Oceananigans.OutputReaders: update_field_time_series!, FieldTimeSeries """ FieldTimeSeries(metadata::Metadata [, arch_or_grid=CPU() ]; @@ -110,13 +45,15 @@ function FieldTimeSeries(metadata::Metadata, grid::AbstractGrid; time_indices_in_memory = 2, time_indexing = Cyclical(), inpainting = default_inpainting(metadata), - cache_inpainted_data = true) + cache_inpainted_data = true, + prefetch = false) # Make sure all the required individual files are downloaded download_dataset(metadata) inpainting isa Int && (inpainting = NearestNeighborInpainting(inpainting)) backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data) + prefetch && (backend = 
PrefetchingBackend(backend)) times = native_times(metadata) loc = LX, LY, LZ = location(metadata) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl new file mode 100644 index 000000000..6f1c864d8 --- /dev/null +++ b/src/DataWrangling/prefetching_backend.jl @@ -0,0 +1,129 @@ +using Oceananigans.OutputReaders: AbstractInMemoryBackend, FlavorOfFTS, FieldTimeSeries +using Oceananigans.Fields: location, instantiated_location + +import Oceananigans.OutputReaders: new_backend +import Oceananigans.Fields: set! + +""" + PrefetchingBackend{B<:DatasetBackend} <: AbstractInMemoryBackend{Int} + +Wrapper around a `DatasetBackend` that hides the next window's I/O behind +the current window's compute by reading into a *buffer* `FieldTimeSeries` +on a background `Task`. The next call to `set!` either copies the +already-loaded buffer into the main FTS (hot path) or falls back to a +synchronous read (cold path), then schedules the prefetch for the window +after that. + +Fields +------ + +- `inner`: the wrapped `DatasetBackend`. Carries the `start`/`length` + window, metadata, inpainting setting — the cold-path read and the + prefetched read both dispatch through it. +- `task`: the in-flight prefetch `Task`, or `nothing`. +- `buffer_fts`: a clone of the main `FieldTimeSeries` whose `data` array + is the destination of the prefetch read. Constructed before the task is + spawned so that the spawn does not race on FTS construction. +- `next_start`: the absolute time index at which `buffer_fts` will start + once its prefetch task completes — compared against the requested + `start` on the next `set!` call to decide hot vs. cold path. +""" +mutable struct PrefetchingBackend{B<:DatasetBackend} <: AbstractInMemoryBackend{Int} + inner :: B + task :: Union{Task, Nothing} + buffer_fts :: Any # erased: a `FieldTimeSeries` whose concrete type + # depends on grid/loc/inpainting and is only set + # at the first `set!`. 
Erasing it keeps the + # backend's type stable across reloads so that + # Oceananigans' `new_backend` round-trip type-checks. + next_start :: Int +end + +PrefetchingBackend(inner::DatasetBackend) = PrefetchingBackend{typeof(inner)}(inner, nothing, nothing, 0) + +# Forward properties so `fts.backend.start`, `fts.backend.metadata`, etc. +# continue to work for downstream code that pokes at the inner backend. +function Base.getproperty(p::PrefetchingBackend, name::Symbol) + if name in (:inner, :task, :buffer_fts, :next_start) + return getfield(p, name) + else + return getproperty(getfield(p, :inner), name) + end +end + +Base.length(p::PrefetchingBackend) = length(p.inner) + +Base.summary(p::PrefetchingBackend) = + string("PrefetchingBackend(", p.inner.start, ", ", p.inner.length, + "; pending_prefetch=", !isnothing(getfield(p, :task)), ")") + +# When Oceananigans rolls the in-memory window forward it calls +# `new_backend(b, start, length)`. Forward the call to the inner backend +# and preserve the prefetch state — the start passed in here is the new +# window's start, which `set!(fts::PrefetchingFTS)` will use to decide if +# the pending prefetch already covers it. +new_backend(p::PrefetchingBackend, start, length) = + PrefetchingBackend(new_backend(p.inner, start, length), + getfield(p, :task), + getfield(p, :buffer_fts), + getfield(p, :next_start)) + +# Dropping prefetch state on adapt — `Task` and the buffer FTS aren't +# meaningful on a different architecture, and they can't be cleanly +# serialised across a CPU↔GPU boundary. +Adapt.adapt_structure(to, p::PrefetchingBackend) = Adapt.adapt(to, p.inner) + +const PrefetchingFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:PrefetchingBackend} + +# Build a clone FTS whose `.backend` is the supplied (inner) DatasetBackend +# and whose `.data` is freshly allocated. 
The clone shares grid, times, +# location, time_indexing and boundary conditions with `fts`, so writing +# its interior is layout-compatible with the main FTS. +function buffer_field_time_series(fts, inner_backend) + LX, LY, LZ = location(fts) + return FieldTimeSeries{LX, LY, LZ}(fts.grid, fts.times; + backend = inner_backend, + time_indexing = fts.time_indexing, + boundary_conditions = fts.boundary_conditions) +end + +function set!(fts::PrefetchingFTS, backend=fts.backend) + needed_start = getfield(backend, :inner).start + + pending_task = getfield(backend, :task) + pending_fts = getfield(backend, :buffer_fts) + pending_start = getfield(backend, :next_start) + + if !isnothing(pending_task) && pending_start == needed_start && !isnothing(pending_fts) + # Hot path — wait on the prefetch (likely already done) and copy + # the buffer FTS into the main FTS. + wait(pending_task) + copyto!(parent(fts.data), parent(pending_fts.data)) + else + # Cold path — drain any stale prefetch, then synchronously load + # via a one-off buffer FTS. We can't dispatch the JRA55-specific + # `set!` directly on `fts` because `fts.backend` is the + # PrefetchingBackend, not the inner one; so we round-trip through + # a clone FTS whose `.backend` is the inner. + if !isnothing(pending_task) + wait(pending_task) + end + + cold_fts = buffer_field_time_series(fts, getfield(backend, :inner)) + set!(cold_fts) + copyto!(parent(fts.data), parent(cold_fts.data)) + end + + # Kick off the next prefetch (cyclically wrapping at the end of times). 
+ Nm = length(getfield(backend, :inner)) + Nt = length(fts.times) + next_start = mod1(needed_start + Nm, Nt) + next_inner = new_backend(getfield(backend, :inner), next_start, Nm) + next_fts = buffer_field_time_series(fts, next_inner) + + setfield!(backend, :next_start, next_start) + setfield!(backend, :buffer_fts, next_fts) + setfield!(backend, :task, Threads.@spawn set!(next_fts)) + + return nothing +end diff --git a/src/DataWrangling/restoring.jl b/src/DataWrangling/restoring.jl index 918edfae0..fe1ea996f 100644 --- a/src/DataWrangling/restoring.jl +++ b/src/DataWrangling/restoring.jl @@ -187,6 +187,13 @@ Keyword Arguments - `cache_inpainted_data`: If `true`, the data is cached to disk after inpainting for later retrieving. Default: `true`. + +- `prefetch`: If `true`, the next sliding window is loaded asynchronously + on a background thread (`Threads.@spawn`) so the I/O cost of + the next reload overlaps the current window's compute. The + hot-path reload becomes a memory copy from the prefetched + buffer; on a cache miss (e.g. checkpointer restart) the cold + path falls back to a synchronous read. Default: `false`. 
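(Editor's aside, outside the patch: the hot/cold reload logic above — wait on a pending prefetch if it covers the requested window, otherwise fall back to a synchronous read, then always schedule the next window — is sketched below in Python with `threading`. All names are hypothetical; `read_window` stands in for the NetCDF I/O, and this is a conceptual analogue, not the Julia implementation.)

```python
import threading

class Prefetcher:
    """Double-buffered prefetch sketch: a background thread fills `buffer`
    for `next_start`; the next load either consumes it (hot path) or
    reads synchronously (cold path), then schedules the window after."""
    def __init__(self, read_window, window, n_times):
        self.read_window, self.window, self.n_times = read_window, window, n_times
        self.task = None
        self.buffer = None
        self.next_start = 0

    def _spawn(self, start):
        self.next_start = start
        self.task = threading.Thread(target=self._fill, args=(start,))
        self.task.start()

    def _fill(self, start):
        # Only this worker touches `buffer` between spawn and join.
        self.buffer = self.read_window(start, self.window)

    def load(self, start):
        if self.task is not None and self.next_start == start:
            self.task.join()            # hot path: prefetch covers `start`
            data = self.buffer
        else:
            if self.task is not None:
                self.task.join()        # drain a stale prefetch first
            data = self.read_window(start, self.window)  # cold path
        # Schedule the next window, wrapping cyclically (1-based).
        self._spawn((start - 1 + self.window) % self.n_times + 1)
        return data

read = lambda start, n: [(start - 1 + k) % 16 + 1 for k in range(n)]
p = Prefetcher(read, window=4, n_times=16)
print(p.load(1))   # cold path -> [1, 2, 3, 4]
print(p.load(5))   # hot path  -> [5, 6, 7, 8]
print(p.load(13))  # cold path (window 9..12 was prefetched but skipped)
```

The draining join on the cold path matches the race invariant the real backend enforces: the buffer may not be reused while a worker might still be writing it.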
""" function DatasetRestoring(metadata::Metadata, arch_or_grid = CPU(); @@ -195,7 +202,8 @@ function DatasetRestoring(metadata::Metadata, time_indices_in_memory = default_time_indices_in_memory(metadata), time_indexing = Cyclical(), inpainting = NearestNeighborInpainting(Inf), - cache_inpainted_data = true) + cache_inpainted_data = true, + prefetch = false) download_dataset(metadata) @@ -203,7 +211,8 @@ function DatasetRestoring(metadata::Metadata, time_indices_in_memory, time_indexing, inpainting, - cache_inpainted_data) + cache_inpainted_data, + prefetch) arch = architecture(fts) mask = on_architecture(arch, mask) diff --git a/src/NumericalEarth.jl b/src/NumericalEarth.jl index 1421eadc5..ce0af8229 100644 --- a/src/NumericalEarth.jl +++ b/src/NumericalEarth.jl @@ -43,7 +43,6 @@ export first_date, last_date, all_dates, - JRA55FieldTimeSeries, LinearlyTaperedPolarMask, DatasetRestoring, ocean_simulation, diff --git a/test/test_downloading.jl b/test/test_downloading.jl index 7fcd841ff..ff3f1f723 100644 --- a/test/test_downloading.jl +++ b/test/test_downloading.jl @@ -8,11 +8,11 @@ include("download_utils.jl") datum = Metadatum(name; dataset=JRA55.RepeatYearJRA55()) filepath = metadata_path(datum) - fts = download_dataset_with_fallback(filepath; dataset_name="JRA55 $name") do - NumericalEarth.JRA55.JRA55FieldTimeSeries(name; backend=NumericalEarth.JRA55.JRA55NetCDFBackend(2)) + download_dataset_with_fallback(filepath; dataset_name="JRA55 $name") do + FieldTimeSeries(Metadata(name; dataset=NumericalEarth.JRA55.RepeatYearJRA55()); time_indices_in_memory=2) end - @test isfile(fts.path) - rm(fts.path; force=true) + @test isfile(filepath) + rm(filepath; force=true) end end diff --git a/test/test_jra55.jl b/test/test_jra55.jl index d562362ab..ea9c77f3a 100644 --- a/test/test_jra55.jl +++ b/test/test_jra55.jl @@ -15,7 +15,7 @@ using NumericalEarth.DataWrangling: compute_native_date_range dates = NumericalEarth.DataWrangling.all_dates(JRA55.RepeatYearJRA55(), test_name) 
end_date = dates[3] - JRA55_fts = JRA55FieldTimeSeries(test_name, arch; end_date) + JRA55_fts = FieldTimeSeries(Metadata(test_name; dataset=JRA55.RepeatYearJRA55(), end_date), arch) test_filename = joinpath(download_JRA55_cache, "RYF.rsds.1990_1991.nc") @test JRA55_fts isa FieldTimeSeries @@ -41,8 +41,8 @@ using NumericalEarth.DataWrangling: compute_native_date_range @info "Testing Cyclical time_indices for JRA55 data on $A..." Nb = 4 - backend = JRA55NetCDFBackend(Nb) - netcdf_JRA55_fts = JRA55FieldTimeSeries(test_name, arch; backend) + netcdf_JRA55_fts = FieldTimeSeries(Metadata(test_name; dataset=JRA55.RepeatYearJRA55()), arch; + time_indices_in_memory=Nb) Nt = length(netcdf_JRA55_fts.times) @test Oceananigans.OutputReaders.time_indices(netcdf_JRA55_fts) == (1, 2, 3, 4) @@ -56,6 +56,40 @@ using NumericalEarth.DataWrangling: compute_native_date_range f₁′ = view(parent(netcdf_JRA55_fts), :, :, 1, 4) f₁′ = Array(f₁′) @test f₁ == f₁′ + + @info "Testing PrefetchingBackend on $A for $test_name..." + # Build a reference (cold) FTS and a prefetching FTS over the same + # window, then drive each through several reloads. After every + # reload the parent data of the prefetching FTS must be byte- + # identical to the reference. The first reload exercises the cold + # fallback (no prior prefetch); subsequent reloads exercise the + # hot path; the wrap from `Nt-3..Nt` back to `1..Nb` exercises the + # cyclical prefetch logic (`mod1(start+Nm, Nt)`). 
+ ref_fts = FieldTimeSeries(Metadata(test_name; dataset=JRA55.RepeatYearJRA55()), arch; + time_indices_in_memory=Nb) + pf_fts = FieldTimeSeries(Metadata(test_name; dataset=JRA55.RepeatYearJRA55()), arch; + time_indices_in_memory=Nb, prefetch=true) + + @test pf_fts.backend isa NumericalEarth.DataWrangling.PrefetchingBackend + @test parent(pf_fts.data) == parent(ref_fts.data) # cold load alignment + @test pf_fts.backend.next_start == Nb + 1 # next prefetch scheduled + + # Reload sequence: + # * Nb+1, 2Nb+1 → straight hot-path advances + # * Nt-Nb+1 → places the next prefetch's window across + # the end-of-times boundary, exercising the + # `mod1(start+Nm, Nt)` wrap when scheduling + # the prefetch + # * 1 → consumes that wrapped prefetch (hot path) + # at the start of the cycle + for next_start in (Nb + 1, 2Nb + 1, Nt - Nb + 1, 1) + ref_fts.backend = Oceananigans.OutputReaders.new_backend(ref_fts.backend, next_start, Nb) + pf_fts.backend = Oceananigans.OutputReaders.new_backend(pf_fts.backend, next_start, Nb) + set!(ref_fts) + set!(pf_fts) + @test parent(pf_fts.data) == parent(ref_fts.data) + @test pf_fts.backend.next_start == mod1(next_start + Nb, Nt) + end end @info "Testing interpolate_field_time_series! on $A..." 
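(Editor's aside, outside the patch: the wrap arithmetic the test above exercises is Julia's `mod1`, a 1-based modular map onto `1..m`. A minimal Python equivalent, with the test's `Nb = 4` reused purely for illustration:)

```python
def mod1(n, m):
    """Julia-style mod1: map n onto the 1-based range 1..m."""
    return (n - 1) % m + 1

Nb, Nt = 4, 16
reload_starts = [Nb + 1, 2 * Nb + 1, Nt - Nb + 1, 1]
# Start index scheduled for the *next* prefetch after each reload:
print([mod1(s + Nb, Nt) for s in reload_starts])  # -> [9, 13, 1, 5]
```

The third entry shows the wrap: a window starting at `Nt - Nb + 1` schedules its successor back at index 1.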
@@ -63,7 +97,7 @@ using NumericalEarth.DataWrangling: compute_native_date_range name = :downwelling_shortwave_radiation dates = NumericalEarth.DataWrangling.all_dates(JRA55.RepeatYearJRA55(), name) end_date = dates[3] - JRA55_fts = JRA55FieldTimeSeries(name, arch; end_date) + JRA55_fts = FieldTimeSeries(Metadata(name; dataset=JRA55.RepeatYearJRA55(), end_date), arch) # Make target grid and field resolution = 1 # degree, eg 1/4 @@ -116,13 +150,12 @@ using NumericalEarth.DataWrangling: compute_native_date_range ##### JRA55 prescribed atmosphere ##### - backend = JRA55NetCDFBackend(2) - atmosphere = JRA55PrescribedAtmosphere(arch; backend, include_rivers_and_icebergs=false) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=2, include_rivers_and_icebergs=false) @test atmosphere isa PrescribedAtmosphere @test isnothing(atmosphere.auxiliary_freshwater_flux) # Test that rivers and icebergs are included in the JRA55 data with the correct frequency - atmosphere = JRA55PrescribedAtmosphere(arch; backend, include_rivers_and_icebergs=true) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=2, include_rivers_and_icebergs=true) @test haskey(atmosphere.auxiliary_freshwater_flux, :rivers) @test haskey(atmosphere.auxiliary_freshwater_flux, :icebergs) @@ -140,8 +173,6 @@ using NumericalEarth.DataWrangling: compute_native_date_range start_date = DateTime("1959-01-01T00:00:00") - 15 * Day(1) # sometime in 1958 end_date = DateTime("1959-01-01T00:00:00") + 85 * Day(1) # sometime in 1959 - backend = JRA55NetCDFBackend(10) - # Use a temporary directory so different architectures don't clash mktempdir("./") do dir # Compute expected file paths so we can fall back to artifacts if needed @@ -151,7 +182,7 @@ using NumericalEarth.DataWrangling: compute_native_date_range filepaths = unique(metadata_path(metadata)) Ta = download_dataset_with_fallback(filepaths; dataset_name="MultiYearJRA55 :temperature") do - JRA55FieldTimeSeries(:temperature, arch; 
dataset, start_date, end_date, backend, dir) + FieldTimeSeries(metadata, arch; time_indices_in_memory=10) end @test Second(end_date - start_date).value ≈ Ta.times[end] - Ta.times[1] @@ -160,5 +191,36 @@ using NumericalEarth.DataWrangling: compute_native_date_range @test Ta[t] isa Field end end + + @info "Testing MultiYearJRA55 single-window crossing year boundary on $A..." + + # Force a single in-memory window to straddle the 1958 → 1959 file + # boundary. Before the per-file `ftsn_loc` fix, the second file's + # iteration in `set!` would clobber the outer `ftsn` and write to the + # wrong slots; this regression test would then leave some in-memory + # slots untouched (zero-valued). + start_date_span = DateTime("1958-12-27T00:00:00") + end_date_span = DateTime("1959-01-05T00:00:00") + + mktempdir("./") do dir + native_dates = NumericalEarth.DataWrangling.all_dates(dataset, :temperature) + dates = compute_native_date_range(native_dates, start_date_span, end_date_span) + metadata = Metadata(:temperature; dataset, dates, dir) + filepaths = unique(metadata_path(metadata)) + + Ta_span = download_dataset_with_fallback(filepaths; + dataset_name="MultiYearJRA55 :temperature year-boundary window") do + # backend window of 80 holds the whole range in a single window + FieldTimeSeries(metadata, arch; time_indices_in_memory=80) + end + + # Every slot in the single in-memory window must carry valid + # (non-zero) atmospheric temperature data. 
+ CUDA.@allowscalar begin + for t in eachindex(Ta_span.times) + @test maximum(abs, interior(Ta_span[t])) > 0 + end + end + end end end From 32ef1f1de4fa24eb619755b46944b8f649f284d9 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Fri, 17 Apr 2026 13:47:55 +0200 Subject: [PATCH 02/44] formatting changes --- src/DataWrangling/prefetching_backend.jl | 46 ++++-------------------- 1 file changed, 7 insertions(+), 39 deletions(-) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl index 6f1c864d8..edda30656 100644 --- a/src/DataWrangling/prefetching_backend.jl +++ b/src/DataWrangling/prefetching_backend.jl @@ -7,35 +7,14 @@ import Oceananigans.Fields: set! """ PrefetchingBackend{B<:DatasetBackend} <: AbstractInMemoryBackend{Int} -Wrapper around a `DatasetBackend` that hides the next window's I/O behind -the current window's compute by reading into a *buffer* `FieldTimeSeries` -on a background `Task`. The next call to `set!` either copies the -already-loaded buffer into the main FTS (hot path) or falls back to a -synchronous read (cold path), then schedules the prefetch for the window -after that. - -Fields ------- - -- `inner`: the wrapped `DatasetBackend`. Carries the `start`/`length` - window, metadata, inpainting setting — the cold-path read and the - prefetched read both dispatch through it. -- `task`: the in-flight prefetch `Task`, or `nothing`. -- `buffer_fts`: a clone of the main `FieldTimeSeries` whose `data` array - is the destination of the prefetch read. Constructed before the task is - spawned so that the spawn does not race on FTS construction. -- `next_start`: the absolute time index at which `buffer_fts` will start - once its prefetch task completes — compared against the requested - `start` on the next `set!` call to decide hot vs. cold path. 
+Wrapper around a `DatasetBackend` that hides the next window's I/O behind the current window's compute by reading into +a *buffer* `FieldTimeSeries` on a background `Task`. The next call to `set!` either copies the already-loaded buffer +into the main FTS (hot path) or falls back to a synchronous read (cold path), then schedules the prefetch for the window after that. """ -mutable struct PrefetchingBackend{B<:DatasetBackend} <: AbstractInMemoryBackend{Int} +mutable struct PrefetchingBackend{B<:DatasetBackend, F} <: AbstractInMemoryBackend{Int} inner :: B task :: Union{Task, Nothing} - buffer_fts :: Any # erased: a `FieldTimeSeries` whose concrete type - # depends on grid/loc/inpainting and is only set - # at the first `set!`. Erasing it keeps the - # backend's type stable across reloads so that - # Oceananigans' `new_backend` round-trip type-checks. + buffer_fts :: F next_start :: Int end @@ -51,26 +30,15 @@ function Base.getproperty(p::PrefetchingBackend, name::Symbol) end end -Base.length(p::PrefetchingBackend) = length(p.inner) +Base.length(p::PrefetchingBackend) = length(p.inner) +Base.summary(p::PrefetchingBackend) = string("PrefetchingBackend(", p.inner.start, ", ", p.inner.length, "; pending_prefetch=", !isnothing(getfield(p, :task)), ")") -Base.summary(p::PrefetchingBackend) = - string("PrefetchingBackend(", p.inner.start, ", ", p.inner.length, - "; pending_prefetch=", !isnothing(getfield(p, :task)), ")") - -# When Oceananigans rolls the in-memory window forward it calls -# `new_backend(b, start, length)`. Forward the call to the inner backend -# and preserve the prefetch state — the start passed in here is the new -# window's start, which `set!(fts::PrefetchingFTS)` will use to decide if -# the pending prefetch already covers it. 
new_backend(p::PrefetchingBackend, start, length) = PrefetchingBackend(new_backend(p.inner, start, length), getfield(p, :task), getfield(p, :buffer_fts), getfield(p, :next_start)) -# Dropping prefetch state on adapt — `Task` and the buffer FTS aren't -# meaningful on a different architecture, and they can't be cleanly -# serialised across a CPU↔GPU boundary. Adapt.adapt_structure(to, p::PrefetchingBackend) = Adapt.adapt(to, p.inner) const PrefetchingFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:PrefetchingBackend} From 9886fad775f712b34e69840d83af9d20226c4849 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Fri, 17 Apr 2026 16:11:11 +0200 Subject: [PATCH 03/44] fix code fragilities --- src/DataWrangling/JRA55/JRA55.jl | 2 +- .../JRA55/JRA55_field_time_series.jl | 76 +++++---- .../metadata_field_time_series.jl | 14 +- src/DataWrangling/prefetching_backend.jl | 150 ++++++++++-------- src/NumericalEarth.jl | 1 + test/test_checkpointer.jl | 58 +++++++ test/test_jra55.jl | 21 +++ 7 files changed, 230 insertions(+), 92 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55.jl b/src/DataWrangling/JRA55/JRA55.jl index 9308c665f..1428580b4 100644 --- a/src/DataWrangling/JRA55/JRA55.jl +++ b/src/DataWrangling/JRA55/JRA55.jl @@ -1,6 +1,6 @@ module JRA55 -export JRA55PrescribedAtmosphere, RepeatYearJRA55, MultiYearJRA55, JRA55NetCDFBackend +export JRA55PrescribedAtmosphere, RepeatYearJRA55, MultiYearJRA55, JRA55NetCDFBackend, JRA55FieldTimeSeries using Oceananigans using Oceananigans.Units diff --git a/src/DataWrangling/JRA55/JRA55_field_time_series.jl b/src/DataWrangling/JRA55/JRA55_field_time_series.jl index dc96e5ce9..db59a5a20 100644 --- a/src/DataWrangling/JRA55/JRA55_field_time_series.jl +++ b/src/DataWrangling/JRA55/JRA55_field_time_series.jl @@ -72,6 +72,27 @@ function compute_bounding_indices(longitude, latitude, grid, LX, LY, λc, φc) return i₁, i₂, j₁, j₂, TX end +# Migration shim — `JRA55FieldTimeSeries` was removed in favor of the +# generic 
`FieldTimeSeries(::Metadata)` path. The shim throws a clear +# error pointing at the new API; can be deleted once downstream callers +# (ClimaOcean.jl, user scripts) have migrated. +function JRA55FieldTimeSeries(args...; kwargs...) + error(""" + `JRA55FieldTimeSeries` was removed; JRA55 now uses the generic + dataset API. Migrate as: + + FieldTimeSeries(Metadata(:variable_name; dataset=RepeatYearJRA55(), + start_date, end_date, dir), + architecture; + time_indices_in_memory = N, + prefetch = false) + + The `InMemory()` backend is no longer supported for JRA55; pass + `time_indices_in_memory = length(metadata)` to keep the whole + series in memory. + """) +end + """ JRA55NetCDFBackend(length [, metadata]) JRA55NetCDFBackend(start, length, metadata) @@ -208,42 +229,42 @@ function set!(fts::JRA55NetCDFFTSRepeatYear, backend=fts.backend) return nothing end -# Tricky case: multiple files per variable -- one file per year -- -# we need to infer the file name from the metadata and split the data loading +# Multi-year case: one file per calendar year. Match each in-memory slot's +# date to the file's time axis by (Y, M, D, H, min) components rather than +# by seconds-since-start. JRA55 files use `DateTimeNoLeap` (365-day) +# internally; doing seconds arithmetic against a Gregorian +# `metadata.dates` would diverge by a multiple of 86400 after every leap +# day passed and silently drop the affected slots. Component matching +# sidesteps this entirely. Feb 29 of leap years has no entry in the file +# and is skipped here too. 
function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend) - metadata = backend.metadata - name = dataset_variable_name(metadata) - start_date = first_date(metadata.dataset, metadata.name) + metadata = backend.metadata + name = dataset_variable_name(metadata) - ftsn = collect(time_indices(fts)) + ftsn = collect(time_indices(fts)) + slot_dates = metadata.dates[ftsn] - # Only open files that actually contain needed time indices - # (metadata.filename maps each time index to its yearly file) needed_files = unique(getfilename(metadata.filename, n) for n in ftsn) for file in needed_files path = joinpath(metadata.dir, file) - ds = Dataset(path) + ds = Dataset(path) - # This can be simplified once we start supporting a - # datetime `Clock` in Oceananigans file_dates = ds["time"][:] - file_indices = 1:length(file_dates) - file_times = zeros(length(file_dates)) - for (t, date) in enumerate(file_dates) - delta = date - start_date - delta = Second(delta).value - file_times[t] = delta - end - # Intersect the time indices with the file times - nn = findall(n -> file_times[n] ∈ fts.times[ftsn], file_indices) - ftsn_loc = findall(n -> fts.times[n] ∈ file_times[nn], ftsn) + nn = Int[] + ftsn_loc = Int[] + for (loc, slot_date) in enumerate(slot_dates) + file_idx = jra55_no_leap_file_index(file_dates, slot_date) + if !isnothing(file_idx) + push!(nn, file_idx) + push!(ftsn_loc, loc) + end + end if !isempty(nn) - # Nodes at the variable location λc = ds["lon"][:] φc = ds["lat"][:] LX, LY, LZ = location(fts) @@ -252,16 +273,17 @@ function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend) if issorted(nn) data = ds[name][i₁:i₂, j₁:j₂, nn] else - # At the cyclical wrap (end of year 60 → start of year 1), - # file-local indices may be unsorted. DiskArrays requires - # sorted indices, so we load in two sorted chunks. 
- m = findfirst(n -> n == 1, nn) + # Defensive: per-file `nn` is normally sorted because we + # iterate `slot_dates` in `ftsn` order and a single file + # holds a single contiguous year, but DiskArrays requires + # sorted indices so we split at the wrap if it occurs. + m = findfirst(n -> n == 1, nn) n1 = nn[1:m-1] n2 = nn[m:end] data1 = ds[name][i₁:i₂, j₁:j₂, n1] data2 = ds[name][i₁:i₂, j₁:j₂, n2] - data = cat(data1, data2, dims=3) + data = cat(data1, data2, dims=3) end close(ds) diff --git a/src/DataWrangling/metadata_field_time_series.jl b/src/DataWrangling/metadata_field_time_series.jl index 7a8bc084e..bc3de290d 100644 --- a/src/DataWrangling/metadata_field_time_series.jl +++ b/src/DataWrangling/metadata_field_time_series.jl @@ -52,12 +52,22 @@ function FieldTimeSeries(metadata::Metadata, grid::AbstractGrid; download_dataset(metadata) inpainting isa Int && (inpainting = NearestNeighborInpainting(inpainting)) - backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data) - prefetch && (backend = PrefetchingBackend(backend)) + inner = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data) times = native_times(metadata) loc = LX, LY, LZ = location(metadata) boundary_conditions = FieldBoundaryConditions(grid, instantiate.(loc)) + + if prefetch + Threads.nthreads() < 2 && @warn "prefetch=true is a no-op with JULIA_NUM_THREADS=$(Threads.nthreads()); start Julia with ≥ 2 threads." + # Buffer FTS is allocated once and reused per reload (see prefetching_backend.jl). 
+ buffer_inner = new_backend(inner, 1, time_indices_in_memory) + buffer_fts = FieldTimeSeries{LX, LY, LZ}(grid, times; backend=buffer_inner, time_indexing, boundary_conditions) + backend = PrefetchingBackend(inner, buffer_fts) + else + backend = inner + end + fts = FieldTimeSeries{LX, LY, LZ}(grid, times; backend, time_indexing, boundary_conditions) set!(fts) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl index edda30656..f9988a99e 100644 --- a/src/DataWrangling/prefetching_backend.jl +++ b/src/DataWrangling/prefetching_backend.jl @@ -1,97 +1,123 @@ -using Oceananigans.OutputReaders: AbstractInMemoryBackend, FlavorOfFTS, FieldTimeSeries -using Oceananigans.Fields: location, instantiated_location +# Asynchronous prefetch wrapper around `DatasetBackend`. Hides the next +# sliding-window's I/O behind the current window's compute by reading into +# a buffer `FieldTimeSeries` on a `Threads.@spawn`-ed task. Every `set!` +# either copies from the prefetched buffer (hot) or loads synchronously +# (cold), then schedules the next window's read. The buffer is allocated +# once at FTS construction and reused for every reload — zero allocation +# per `set!`. +# +# Race invariant: between the spawn at the end of one `set!` and the +# `wait` at the start of the next, the worker is mutating +# `buffer_fts.data`. No code outside `set!(::PrefetchingFTS)` may touch +# `buffer_fts` in that window. Two enforcement points: `:buffer_fts` is +# not forwarded by `getproperty`, and `Adapt.adapt_structure` returns +# only the inner backend. `wait_for_prefetch!` is the safe drain hook. +# +# Requires `JULIA_NUM_THREADS ≥ 2` to actually overlap; one thread makes +# the spawn cooperatively-scheduled and the optimisation a no-op. + +using Oceananigans.OutputReaders: AbstractInMemoryBackend, FlavorOfFTS, FieldTimeSeries, time_index +using Oceananigans.Fields: location import Oceananigans.OutputReaders: new_backend import Oceananigans.Fields: set! 
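The hot/cold-path protocol that `set!(::PrefetchingFTS)` implements below can be sketched independently of the package. Everything in this snippet is an illustrative stand-in (`SketchPrefetcher`, `load!`, `reload!` are hypothetical names, not the NumericalEarth API); `load!` plays the role of the expensive NetCDF window read:

```julia
# Double-buffer prefetch skeleton: one reusable buffer, one in-flight Task.
mutable struct SketchPrefetcher
    pending    :: Union{Task, Nothing}  # in-flight read, or nothing
    buffer     :: Vector{Float64}       # reused every reload — no per-reload allocation
    next_start :: Int                   # window the buffer holds once `pending` completes
end

# Stand-in for the expensive NetCDF read of one sliding window.
load!(dest, start) = (sleep(0.01); dest .= start; nothing)

function reload!(main, p::SketchPrefetcher, needed_start)
    if !isnothing(p.pending) && p.next_start == needed_start
        wait(p.pending)                          # hot path: data is (likely) already in `buffer`
    else
        isnothing(p.pending) || wait(p.pending)  # drain a stale prefetch before reusing `buffer`
        load!(p.buffer, needed_start)            # cold path: synchronous read
    end
    copyto!(main, p.buffer)                      # publish the window to the consumer
    p.next_start = needed_start + 1              # predict the next window...
    p.pending = Threads.@spawn load!(p.buffer, p.next_start)  # ...and read it in the background
    return main
end
```

With `JULIA_NUM_THREADS ≥ 2` the spawned `load!` overlaps whatever compute happens between `reload!` calls; with one thread it still runs, just cooperatively scheduled.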
-""" - PrefetchingBackend{B<:DatasetBackend} <: AbstractInMemoryBackend{Int} - -Wrapper around a `DatasetBackend` that hides the next window's I/O behind the current window's compute by reading into -a *buffer* `FieldTimeSeries` on a background `Task`. The next call to `set!` either copies the already-loaded buffer -into the main FTS (hot path) or falls back to a synchronous read (cold path), then schedules the prefetch for the window after that. -""" -mutable struct PrefetchingBackend{B<:DatasetBackend, F} <: AbstractInMemoryBackend{Int} +mutable struct PrefetchingBackend{B<:DatasetBackend, F<:FieldTimeSeries} <: AbstractInMemoryBackend{Int} inner :: B - task :: Union{Task, Nothing} + pending :: Union{Task, Nothing} buffer_fts :: F next_start :: Int end -PrefetchingBackend(inner::DatasetBackend) = PrefetchingBackend{typeof(inner)}(inner, nothing, nothing, 0) +PrefetchingBackend(inner::DatasetBackend, buffer_fts::FieldTimeSeries) = + PrefetchingBackend{typeof(inner), typeof(buffer_fts)}(inner, nothing, buffer_fts, 0) -# Forward properties so `fts.backend.start`, `fts.backend.metadata`, etc. -# continue to work for downstream code that pokes at the inner backend. +# `:buffer_fts` deliberately omitted — see race invariant in preamble. 
function Base.getproperty(p::PrefetchingBackend, name::Symbol) - if name in (:inner, :task, :buffer_fts, :next_start) + if name in (:inner, :pending, :next_start) return getfield(p, name) else return getproperty(getfield(p, :inner), name) end end -Base.length(p::PrefetchingBackend) = length(p.inner) -Base.summary(p::PrefetchingBackend) = string("PrefetchingBackend(", p.inner.start, ", ", p.inner.length, "; pending_prefetch=", !isnothing(getfield(p, :task)), ")") +Base.length(p::PrefetchingBackend) = length(p.inner) + +Base.summary(p::PrefetchingBackend) = + string("PrefetchingBackend(", p.inner.start, ", ", p.inner.length, + "; pending=", !isnothing(getfield(p, :pending)), ")") -new_backend(p::PrefetchingBackend, start, length) = - PrefetchingBackend(new_backend(p.inner, start, length), - getfield(p, :task), - getfield(p, :buffer_fts), - getfield(p, :next_start)) +# Mutate in place rather than constructing a fresh wrapper — keeps the +# `pending`/`buffer_fts`/`next_start` mutable state in exactly one object. +function new_backend(p::PrefetchingBackend, start, length) + setfield!(p, :inner, new_backend(getfield(p, :inner), start, length)) + return p +end -Adapt.adapt_structure(to, p::PrefetchingBackend) = Adapt.adapt(to, p.inner) +Adapt.adapt_structure(to, p::PrefetchingBackend) = Adapt.adapt(to, getfield(p, :inner)) const PrefetchingFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:PrefetchingBackend} -# Build a clone FTS whose `.backend` is the supplied (inner) DatasetBackend -# and whose `.data` is freshly allocated. The clone shares grid, times, -# location, time_indexing and boundary conditions with `fts`, so writing -# its interior is layout-compatible with the main FTS. 
-function buffer_field_time_series(fts, inner_backend) - LX, LY, LZ = location(fts) - return FieldTimeSeries{LX, LY, LZ}(fts.grid, fts.times; - backend = inner_backend, - time_indexing = fts.time_indexing, - boundary_conditions = fts.boundary_conditions) +""" + wait_for_prefetch!(backend::PrefetchingBackend) + +Block until the in-flight prefetch task completes, then clear it. +Required before any code that needs a consistent view of the buffer FTS +(checkpointing, JLD2 serialisation, manual `getfield(..., :buffer_fts)`). +""" +function wait_for_prefetch!(p::PrefetchingBackend) + pending = getfield(p, :pending) + if !isnothing(pending) + wait(pending) + setfield!(p, :pending, nothing) + end + return nothing end -function set!(fts::PrefetchingFTS, backend=fts.backend) - needed_start = getfield(backend, :inner).start +function set!(fts::PrefetchingFTS, backend::PrefetchingBackend = fts.backend) + needed_start = getfield(backend, :inner).start + pending = getfield(backend, :pending) + pending_start = getfield(backend, :next_start) + buffer_fts = getfield(backend, :buffer_fts) - pending_task = getfield(backend, :task) - pending_fts = getfield(backend, :buffer_fts) - pending_start = getfield(backend, :next_start) + # Cleared up-front so a failed prefetch isn't re-thrown on every later set!. + setfield!(backend, :pending, nothing) - if !isnothing(pending_task) && pending_start == needed_start && !isnothing(pending_fts) - # Hot path — wait on the prefetch (likely already done) and copy - # the buffer FTS into the main FTS. - wait(pending_task) - copyto!(parent(fts.data), parent(pending_fts.data)) + if !isnothing(pending) && pending_start == needed_start + wait(pending) else - # Cold path — drain any stale prefetch, then synchronously load - # via a one-off buffer FTS. 
We can't dispatch the JRA55-specific - # `set!` directly on `fts` because `fts.backend` is the - # PrefetchingBackend, not the inner one; so we round-trip through - # a clone FTS whose `.backend` is the inner. - if !isnothing(pending_task) - wait(pending_task) - end - - cold_fts = buffer_field_time_series(fts, getfield(backend, :inner)) - set!(cold_fts) - copyto!(parent(fts.data), parent(cold_fts.data)) + !isnothing(pending) && wait(pending) + Nm = length(getfield(backend, :inner)) + buffer_fts.backend = new_backend(buffer_fts.backend, needed_start, Nm) + set!(buffer_fts) end - # Kick off the next prefetch (cyclically wrapping at the end of times). + copyto!(parent(fts.data), parent(buffer_fts.data)) + + # Time-indexing-aware next-window prediction: `time_index` wraps + # via mod1 for Cyclical and clamps to Nt for Linear/Clamp. Nm = length(getfield(backend, :inner)) Nt = length(fts.times) - next_start = mod1(needed_start + Nm, Nt) - next_inner = new_backend(getfield(backend, :inner), next_start, Nm) - next_fts = buffer_field_time_series(fts, next_inner) + new_next = time_index(buffer_fts.backend, fts.time_indexing, Nt, Nm + 1) + + if new_next == needed_start + # Linear/Clamp at end-of-data: window can't advance, no prefetch. + setfield!(backend, :next_start, 0) + return nothing + end - setfield!(backend, :next_start, next_start) - setfield!(backend, :buffer_fts, next_fts) - setfield!(backend, :task, Threads.@spawn set!(next_fts)) + buffer_fts.backend = new_backend(buffer_fts.backend, new_next, Nm) + setfield!(backend, :next_start, new_next) + # Worker-side @error logs context (spawn site is gone by `wait` time); rethrow preserves the original. 
+ setfield!(backend, :pending, Threads.@spawn begin + try + set!(buffer_fts) + catch e + m = buffer_fts.backend.metadata + @error "PrefetchingBackend: prefetch task failed" dataset=typeof(m.dataset) variable=m.name window=(new_next, new_next + Nm - 1) exception=(e, catch_backtrace()) + rethrow() + end + end) return nothing end diff --git a/src/NumericalEarth.jl b/src/NumericalEarth.jl index ce0af8229..1421eadc5 100644 --- a/src/NumericalEarth.jl +++ b/src/NumericalEarth.jl @@ -43,7 +43,6 @@ export first_date, last_date, all_dates, - JRA55FieldTimeSeries, LinearlyTaperedPolarMask, DatasetRestoring, ocean_simulation, diff --git a/test/test_checkpointer.jl b/test/test_checkpointer.jl index b52855289..71ee4ada8 100644 --- a/test/test_checkpointer.jl +++ b/test/test_checkpointer.jl @@ -110,5 +110,63 @@ using Oceananigans.OutputWriters: Checkpointer # Cleanup rm.(glob("$(prefix)_iteration*.jld2"), force=true) + + # ---- Same workflow but with prefetch=true on the JRA55 atmosphere. + # Verifies that (1) the JLD2 checkpointer doesn't choke on the + # PrefetchingBackend's mutable state and (2) the restored model + # produces the same data as the reference (i.e. the cold-path + # fallback on restart correctly re-prefetches). With nthreads()=1 + # the prefetch becomes a no-op, which still exercises the + # serialisation and cold-fallback logic. 
+ @info "Testing EarthSystemModel checkpointing with prefetch=true on $A" + + function make_coupled_model_prefetch(grid) + @inline hi(λ, φ) = φ > 70 || φ < -70 + + ocean = ocean_simulation(grid, closure=nothing) + set!(ocean.model, T=20, S=35, u=0.01, v=-0.005) + sea_ice = sea_ice_simulation(grid, ocean) + set!(sea_ice.model, h=hi, ℵ=hi) + + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4, prefetch=true) + + return OceanSeaIceModel(ocean, sea_ice; atmosphere) + end + + # Reference run with prefetch + model = make_coupled_model_prefetch(grid) + run!(Simulation(model, Δt=60, stop_iteration=3)) + run!(Simulation(model, Δt=60, stop_iteration=6)) + + ref_T = Array(interior(model.ocean.model.tracers.T)) + ref_h = Array(interior(model.sea_ice.model.ice_thickness)) + ref_time = model.clock.time + + # Checkpointed run with prefetch + model = make_coupled_model_prefetch(grid) + simulation = Simulation(model, Δt=60, stop_iteration=3) + prefix_pf = "osm_checkpointer_prefetch_test_$(typeof(arch))" + simulation.output_writers[:checkpointer] = Checkpointer(simulation.model; + schedule = IterationInterval(3), + prefix = prefix_pf) + run!(simulation) + @test isfile("$(prefix_pf)_iteration3.jld2") # JLD2 didn't choke on prefetch state + + model = make_coupled_model_prefetch(grid) + simulation = Simulation(model, Δt=60, stop_iteration=6) + simulation.output_writers[:checkpointer] = Checkpointer(model; + schedule = IterationInterval(3), + prefix = prefix_pf) + set!(simulation; checkpoint=:latest) + set!(simulation; iteration=3) + run!(simulation) + + T = Array(interior(model.ocean.model.tracers.T)) + h = Array(interior(model.sea_ice.model.ice_thickness)) + @test T ≈ ref_T rtol=1e-13 + @test h ≈ ref_h rtol=1e-13 + @test model.clock.time == ref_time + + rm.(glob("$(prefix_pf)_iteration*.jld2"), force=true) end end diff --git a/test/test_jra55.jl b/test/test_jra55.jl index ea9c77f3a..e3fe40e2e 100644 --- a/test/test_jra55.jl +++ b/test/test_jra55.jl @@ -92,6 +92,27 
@@ using NumericalEarth.DataWrangling: compute_native_date_range end end + @info "Testing Field(::JRA55Metadatum) on $A..." + # Locks in: position-based RepeatYear file_index lookup, halo + # periodic wrap, and agreement between the per-Metadatum Field + # path and the chunked-file FTS path at the same time index. + ds_var = :downwelling_shortwave_radiation + all_jra55_dates = NumericalEarth.DataWrangling.all_dates(JRA55.RepeatYearJRA55(), ds_var) + + md_first = Metadatum(ds_var; dataset=JRA55.RepeatYearJRA55()) + f_first = Field(md_first, arch) + @test f_first isa Field + @test size(f_first) == (640, 320, 1) + CUDA.@allowscalar @test f_first[1, 1, 1] == 430.98105f0 + CUDA.@allowscalar @test view(f_first.data, 1, :, 1) == view(f_first.data, 641, :, 1) + + md_mid = Metadatum(ds_var; dataset=JRA55.RepeatYearJRA55(), date=all_jra55_dates[100]) + f_mid = Field(md_mid, arch) + # Same time index loaded via the chunked-file FTS path → must agree + fts100 = FieldTimeSeries(Metadata(ds_var; dataset=JRA55.RepeatYearJRA55(), end_date=all_jra55_dates[100]), arch; time_indices_in_memory=100) + CUDA.@allowscalar @test f_mid[1, 1, 1] == fts100[1, 1, 1, 100] + CUDA.@allowscalar @test f_mid[640, 1, 1] == fts100[640, 1, 1, 100] + @info "Testing interpolate_field_time_series! on $A..." 
name = :downwelling_shortwave_radiation From 8c2f5d5b30c5565007c25f383d8a38d9d847381d Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Fri, 17 Apr 2026 18:12:47 +0200 Subject: [PATCH 04/44] Thread `prefetch` through ECCO atmosphere and restoring; tidy PrefetchingBackend --- src/DataWrangling/ECCO/ECCO_atmosphere.jl | 30 +++++------ .../metadata_field_time_series.jl | 6 ++- src/DataWrangling/prefetching_backend.jl | 54 ++++++------------- src/DataWrangling/restoring.jl | 8 +-- 4 files changed, 34 insertions(+), 64 deletions(-) diff --git a/src/DataWrangling/ECCO/ECCO_atmosphere.jl b/src/DataWrangling/ECCO/ECCO_atmosphere.jl index 085e66a4e..7858dbfcb 100644 --- a/src/DataWrangling/ECCO/ECCO_atmosphere.jl +++ b/src/DataWrangling/ECCO/ECCO_atmosphere.jl @@ -27,30 +27,24 @@ function ECCOPrescribedAtmosphere(architecture = CPU(), FT = Float32; end_date = last_date(dataset, :air_temperature), dir = default_download_directory(dataset), time_indexing = Cyclical(), + prefetch = false, time_indices_in_memory = 10, surface_layer_height = 2, # meters other_kw...) - ua_meta = Metadata(:eastward_wind; dataset, start_date, end_date, dir) - va_meta = Metadata(:northward_wind; dataset, start_date, end_date, dir) - Ta_meta = Metadata(:air_temperature; dataset, start_date, end_date, dir) - qa_meta = Metadata(:air_specific_humidity; dataset, start_date, end_date, dir) - pa_meta = Metadata(:sea_level_pressure; dataset, start_date, end_date, dir) - ℐꜜˡʷ_meta = Metadata(:downwelling_longwave; dataset, start_date, end_date, dir) - ℐꜜˢʷ_meta = Metadata(:downwelling_shortwave; dataset, start_date, end_date, dir) - Fr_meta = Metadata(:rain_freshwater_flux; dataset, start_date, end_date, dir) - - kw = (; time_indices_in_memory, time_indexing) + kw = (; time_indexing, time_indices_in_memory, prefetch) kw = merge(kw, other_kw) - ua = FieldTimeSeries(ua_meta, architecture; kw...) - va = FieldTimeSeries(va_meta, architecture; kw...) - Ta = FieldTimeSeries(Ta_meta, architecture; kw...) - qa = FieldTimeSeries(qa_meta, architecture; kw...)
- pa = FieldTimeSeries(pa_meta, architecture; kw...) - ℐꜜˡʷ = FieldTimeSeries(ℐꜜˡʷ_meta, architecture; kw...) - ℐꜜˢʷ = FieldTimeSeries(ℐꜜˢʷ_meta, architecture; kw...) - Fr = FieldTimeSeries(Fr_meta, architecture; kw...) + ecco_fts(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir), architecture; kw...) + + ua = ecco_fts(:eastward_wind) + va = ecco_fts(:northward_wind) + Ta = ecco_fts(:air_temperature) + qa = ecco_fts(:air_specific_humidity) + pa = ecco_fts(:sea_level_pressure) + ℐꜜˡʷ = ecco_fts(:downwelling_longwave) + ℐꜜˢʷ = ecco_fts(:downwelling_shortwave) + Fr = ecco_fts(:rain_freshwater_flux) auxiliary_freshwater_flux = nothing freshwater_flux = (; rain = Fr) diff --git a/src/DataWrangling/metadata_field_time_series.jl b/src/DataWrangling/metadata_field_time_series.jl index bc3de290d..61eb57c6a 100644 --- a/src/DataWrangling/metadata_field_time_series.jl +++ b/src/DataWrangling/metadata_field_time_series.jl @@ -1,4 +1,4 @@ -using Oceananigans.Architectures: AbstractArchitecture +using Oceananigans.Architectures: AbstractArchitecture, architecture using Oceananigans.Grids: AbstractGrid using Oceananigans.Fields: interpolate! 
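With the keywords merged into one `kw` tuple above, opting into prefetching from the prescribed-atmosphere entry point is a single flag. A usage sketch (it requires the package plus a downloaded ECCO dataset, so it is not runnable standalone; only `prefetch` and `time_indices_in_memory` are the keywords this patch touches):

```julia
using NumericalEarth  # assumes Oceananigans' CPU() is also in scope

atmosphere = ECCOPrescribedAtmosphere(CPU();
                                      time_indices_in_memory = 10,  # sliding-window size
                                      prefetch = true)              # overlap reloads with compute
```

Because `other_kw...` is merged into the per-variable `FieldTimeSeries(::Metadata)` calls, the same `prefetch = true` flag is accepted by `FieldTimeSeries(::Metadata)` and `DatasetRestoring` directly.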
@@ -48,9 +48,11 @@ function FieldTimeSeries(metadata::Metadata, grid::AbstractGrid; cache_inpainted_data = true, prefetch = false) - # Make sure all the required individual files are downloaded download_dataset(metadata) + # Detect "the user's grid IS the native grid" structurally + on_native_grid = grid == native_grid(metadata, architecture(grid)) + inpainting isa Int && (inpainting = NearestNeighborInpainting(inpainting)) inner = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl index f9988a99e..717170cb2 100644 --- a/src/DataWrangling/prefetching_backend.jl +++ b/src/DataWrangling/prefetching_backend.jl @@ -1,20 +1,14 @@ -# Asynchronous prefetch wrapper around `DatasetBackend`. Hides the next -# sliding-window's I/O behind the current window's compute by reading into -# a buffer `FieldTimeSeries` on a `Threads.@spawn`-ed task. Every `set!` -# either copies from the prefetched buffer (hot) or loads synchronously -# (cold), then schedules the next window's read. The buffer is allocated -# once at FTS construction and reused for every reload — zero allocation -# per `set!`. +# Asynchronous prefetch wrapper around `DatasetBackend`. Hides the next sliding-window's I/O behind the current window's compute +# by reading into a buffer `FieldTimeSeries` on a `Threads.@spawn`-ed task. Every `set!` either copies from the prefetched buffer +# (hot) or loads synchronously (cold), then schedules the next window's read. The buffer is allocated once at FTS construction and +# reused for every reload — zero allocation per `set!`. # -# Race invariant: between the spawn at the end of one `set!` and the -# `wait` at the start of the next, the worker is mutating -# `buffer_fts.data`. No code outside `set!(::PrefetchingFTS)` may touch -# `buffer_fts` in that window. 
Two enforcement points: `:buffer_fts` is -# not forwarded by `getproperty`, and `Adapt.adapt_structure` returns -# only the inner backend. `wait_for_prefetch!` is the safe drain hook. +# Race invariant: between the spawn at the end of one `set!` and the `wait` at the start of the next, the worker is mutating +# `buffer_fts.data`. No code outside `set!(::PrefetchingFTS)` may touch `buffer_fts` in that window. +# Two enforcement points: accessing `:buffer_fts` through `getproperty` emits a warning, and `Adapt.adapt_structure` returns +# only the inner backend. # -# Requires `JULIA_NUM_THREADS ≥ 2` to actually overlap; one thread makes -# the spawn cooperatively-scheduled and the optimisation a no-op. +# Requires `JULIA_NUM_THREADS ≥ 2` to actually overlap; one thread makes the spawn cooperatively-scheduled and the optimisation a no-op. using Oceananigans.OutputReaders: AbstractInMemoryBackend, FlavorOfFTS, FieldTimeSeries, time_index using Oceananigans.Fields: location @@ -29,23 +23,23 @@ mutable struct PrefetchingBackend{B<:DatasetBackend, F<:FieldTimeSeries} <: Abst next_start :: Int end -PrefetchingBackend(inner::DatasetBackend, buffer_fts::FieldTimeSeries) = - PrefetchingBackend{typeof(inner), typeof(buffer_fts)}(inner, nothing, buffer_fts, 0) +PrefetchingBackend(inner::DatasetBackend, buffer_fts::FieldTimeSeries) = PrefetchingBackend{typeof(inner), typeof(buffer_fts)}(inner, nothing, buffer_fts, 0) -# `:buffer_fts` deliberately omitted — see race invariant in preamble. +# Accessing `:buffer_fts` warns instead of erroring — see race invariant in preamble. function Base.getproperty(p::PrefetchingBackend, name::Symbol) if name in (:inner, :pending, :next_start) return getfield(p, name) + elseif name == :buffer_fts + @warn "`buffer_fts` is an internal buffer mutated by a background prefetch task. " * + "Mutating it manually may lead to undefined behavior; avoid modifying it." 
+ return getfield(p, name) else return getproperty(getfield(p, :inner), name) end end Base.length(p::PrefetchingBackend) = length(p.inner) - -Base.summary(p::PrefetchingBackend) = - string("PrefetchingBackend(", p.inner.start, ", ", p.inner.length, - "; pending=", !isnothing(getfield(p, :pending)), ")") +Base.summary(p::PrefetchingBackend) = string("PrefetchingBackend(", p.inner.start, ", ", p.inner.length, "; pending=", !isnothing(getfield(p, :pending)), ")") # Mutate in place rather than constructing a fresh wrapper — keeps the # `pending`/`buffer_fts`/`next_start` mutable state in exactly one object. @@ -58,22 +52,6 @@ Adapt.adapt_structure(to, p::PrefetchingBackend) = Adapt.adapt(to, getfield(p, : const PrefetchingFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:PrefetchingBackend} -""" - wait_for_prefetch!(backend::PrefetchingBackend) - -Block until the in-flight prefetch task completes, then clear it. -Required before any code that needs a consistent view of the buffer FTS -(checkpointing, JLD2 serialisation, manual `getfield(..., :buffer_fts)`). -""" -function wait_for_prefetch!(p::PrefetchingBackend) - pending = getfield(p, :pending) - if !isnothing(pending) - wait(pending) - setfield!(p, :pending, nothing) - end - return nothing -end - function set!(fts::PrefetchingFTS, backend::PrefetchingBackend = fts.backend) needed_start = getfield(backend, :inner).start pending = getfield(backend, :pending) diff --git a/src/DataWrangling/restoring.jl b/src/DataWrangling/restoring.jl index fe1ea996f..1aaa59af8 100644 --- a/src/DataWrangling/restoring.jl +++ b/src/DataWrangling/restoring.jl @@ -188,12 +188,8 @@ Keyword Arguments - `cache_inpainted_data`: If `true`, the data is cached to disk after inpainting for later retrieving. Default: `true`. -- `prefetch`: If `true`, the next sliding window is loaded asynchronously - on a background thread (`Threads.@spawn`) so the I/O cost of - the next reload overlaps the current window's compute. 
The - hot-path reload becomes a memory copy from the prefetched - buffer; on a cache miss (e.g. checkpointer restart) the cold - path falls back to a synchronous read. Default: `false`. +- `prefetch`: If `true`, hide the next reload's I/O behind compute via a background `Threads.@spawn`. + Intended for long-lived FTSes; short-lived ones leak one prefetch task. Default: `false`. """ function DatasetRestoring(metadata::Metadata, arch_or_grid = CPU(); From fde7baf5d438b10da24a3b83a1e841e4af35a11c Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 14:33:50 +0200 Subject: [PATCH 05/44] fix all tests and examples --- examples/generate_surface_fluxes.jl | 4 ++-- examples/meridional_heat_transport_ecco.jl | 2 +- examples/near_global_ocean_simulation.jl | 4 ++-- examples/one_degree_simulation.jl | 2 +- examples/veros_ocean_forced_simulation.jl | 2 +- experiments/arctic_simulation.jl | 2 +- .../earth_system_coupled_simulation.jl | 2 +- experiments/flux_climatology/flux_climatology.jl | 2 +- experiments/one_degree_simulation/debug_tides.jl | 3 +-- .../one_degree_simulation/generate_tidal_forcing.jl | 3 +-- .../one_degree_simulation/one_degree_simulation.jl | 2 +- test/runtests.jl | 4 ++-- test/test_checkpointer.jl | 3 +-- test/test_ocean_only_model.jl | 6 ++---- test/test_ocean_sea_ice_model.jl | 3 +-- test/test_sea_ice_ocean_heat_fluxes.jl | 12 ++++-------- 16 files changed, 23 insertions(+), 33 deletions(-) diff --git a/examples/generate_surface_fluxes.jl b/examples/generate_surface_fluxes.jl index 3043bef0f..3c7ac17b0 100644 --- a/examples/generate_surface_fluxes.jl +++ b/examples/generate_surface_fluxes.jl @@ -45,9 +45,9 @@ save("ECCO_continents.png", fig) # - downwelling longwave radiation # # We load in memory only the first two time indices, corresponding to January 1st -# (at 00:00 AM and 03:00 AM), by using `JRA55NetCDFBackend(2)`. +# (at 00:00 AM and 03:00 AM), by using `time_indices_in_memory = 2`. 
-atmosphere = JRA55PrescribedAtmosphere(; backend = JRA55NetCDFBackend(2)) +atmosphere = JRA55PrescribedAtmosphere(; time_indices_in_memory = 2) ocean = ocean_simulation(grid, closure=nothing) # Now that we have an atmosphere and ocean, we `set!` the ocean temperature and salinity diff --git a/examples/meridional_heat_transport_ecco.jl b/examples/meridional_heat_transport_ecco.jl index 58b3e7849..4006dae1b 100755 --- a/examples/meridional_heat_transport_ecco.jl +++ b/examples/meridional_heat_transport_ecco.jl @@ -43,7 +43,7 @@ set!(ocean.model, T=ecco_temperature, S=ecco_salinity) set!(sea_ice.model, h=ecco_sea_ice_thickness, ℵ=ecco_sea_ice_concentration) radiation = Radiation(arch) -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(80), +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory = 80, include_rivers_and_icebergs = false) esm = OceanSeaIceModel(ocean, sea_ice; atmosphere, radiation) diff --git a/examples/near_global_ocean_simulation.jl b/examples/near_global_ocean_simulation.jl index 1cc0c065f..14f5ce91a 100644 --- a/examples/near_global_ocean_simulation.jl +++ b/examples/near_global_ocean_simulation.jl @@ -102,9 +102,9 @@ radiation = Radiation(arch) # The JRA55 dataset provides atmospheric data such as temperature, humidity, and winds # to calculate turbulent fluxes using bulk formulae, see [`InterfaceComputations`](@ref NumericalEarth.EarthSystemModels.InterfaceComputations). # The number of snapshots that are loaded into memory is determined by -# the `backend`. Here, we load 41 snapshots at a time into memory. +# `time_indices_in_memory`. Here, we load 41 snapshots at a time into memory. 
-atmosphere = JRA55PrescribedAtmosphere(arch; backend = JRA55NetCDFBackend(41), +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory = 41, include_rivers_and_icebergs = false) # ## The coupled simulation diff --git a/examples/one_degree_simulation.jl b/examples/one_degree_simulation.jl index 60a438562..cb792ff03 100644 --- a/examples/one_degree_simulation.jl +++ b/examples/one_degree_simulation.jl @@ -96,7 +96,7 @@ set!(sea_ice.model, h=ecco_sea_ice_thickness, ℵ=ecco_sea_ice_concentration) # We force the simulation with a JRA55-do atmospheric reanalysis. radiation = Radiation(arch) -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(80), +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory = 80, include_rivers_and_icebergs = false) # ### Coupled simulation diff --git a/examples/veros_ocean_forced_simulation.jl b/examples/veros_ocean_forced_simulation.jl index bb95573f5..65db31b33 100644 --- a/examples/veros_ocean_forced_simulation.jl +++ b/examples/veros_ocean_forced_simulation.jl @@ -68,7 +68,7 @@ ocean.set_forcing = set_forcing_tke_only # This includes 2-meter wind velocity, temperature, humidity, downwelling longwave and shortwave # radiation, as well as freshwater fluxes. -atmos = JRA55PrescribedAtmosphere(; backend = JRA55NetCDFBackend(10)) +atmos = JRA55PrescribedAtmosphere(; time_indices_in_memory = 10) # The coupled ocean--atmosphere model. # We use the default radiation model and we do not couple an ice model for simplicity. 
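The same one-for-one substitution recurs in every example above, and `JRA55NetCDFBackend(N)` survives only as a shim returning `DatasetBackend(N, nothing; inpainting=nothing)`, so unported call sites keep working. A before/after sketch (illustrative; `arch` is whatever architecture the script already defines):

```julia
# Before: window size chosen by constructing a backend object
atmosphere = JRA55PrescribedAtmosphere(arch; backend = JRA55NetCDFBackend(41))

# After: window size is a plain keyword; the chunked-yearly NetCDF
# dispatch path is selected internally via `inpainting = nothing`
atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory = 41)
```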
diff --git a/experiments/arctic_simulation.jl b/experiments/arctic_simulation.jl index afabdf826..a602fe3b8 100644 --- a/experiments/arctic_simulation.jl +++ b/experiments/arctic_simulation.jl @@ -89,7 +89,7 @@ set!(sea_ice.model, h=Metadatum(:sea_ice_thickness; dataset), ##### A Prescribed Atmosphere model ##### -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(40)) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=40) radiation = Radiation() ##### diff --git a/experiments/coupled_simulation/earth_system_coupled_simulation.jl b/experiments/coupled_simulation/earth_system_coupled_simulation.jl index 334ebce91..975602b65 100644 --- a/experiments/coupled_simulation/earth_system_coupled_simulation.jl +++ b/experiments/coupled_simulation/earth_system_coupled_simulation.jl @@ -79,7 +79,7 @@ salinity = ECCOMetadata(:salinity; dir="./") ice_thickness = ECCOMetadata(:sea_ice_thickness; dir="./") ice_concentration = ECCOMetadata(:sea_ice_concentration; dir="./") -atmosphere = JRA55PrescribedAtmosphere(arch, backend=JRA55NetCDFBackend(20)) +atmosphere = JRA55PrescribedAtmosphere(arch, time_indices_in_memory=20) radiation = Radiation(ocean_albedo = LatitudeDependentAlbedo(), sea_ice_albedo=0.6) set!(ocean.model, T=temperature, S=salinity) diff --git a/experiments/flux_climatology/flux_climatology.jl b/experiments/flux_climatology/flux_climatology.jl index 20ad23d60..c041b52e9 100644 --- a/experiments/flux_climatology/flux_climatology.jl +++ b/experiments/flux_climatology/flux_climatology.jl @@ -191,7 +191,7 @@ heat_capacity(ocean::Simulation{<:PrescribedOcean}) = 3995.6 ##### A prescribed atmosphere... ##### -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(1000)) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=1000) ##### ##### A prescribed earth... 
diff --git a/experiments/one_degree_simulation/debug_tides.jl b/experiments/one_degree_simulation/debug_tides.jl index a06e975bf..8c7c544ba 100644 --- a/experiments/one_degree_simulation/debug_tides.jl +++ b/experiments/one_degree_simulation/debug_tides.jl @@ -8,9 +8,8 @@ using Statistics import SPICE -backend = JRA55NetCDFBackend(41) ρᵒᶜ = 1020 -pa = JRA55_field_time_series(:sea_level_pressure; backend) +pa = JRA55_field_time_series(:sea_level_pressure; time_indices_in_memory=41) grid = pa.grid # dt = 30 * 60 # minutes diff --git a/experiments/one_degree_simulation/generate_tidal_forcing.jl b/experiments/one_degree_simulation/generate_tidal_forcing.jl index 6bd301110..ff76cb277 100644 --- a/experiments/one_degree_simulation/generate_tidal_forcing.jl +++ b/experiments/one_degree_simulation/generate_tidal_forcing.jl @@ -6,8 +6,7 @@ using Dates using GLMakie import SPICE -backend = JRA55NetCDFBackend(41) -pa = JRA55_field_time_series(:sea_level_pressure; backend) +pa = JRA55_field_time_series(:sea_level_pressure; time_indices_in_memory=41) pa_with_tides = FieldTimeSeries{Center, Center, Nothing}(pa.grid, pa.times, path = "sea_level_pressure_plus_tides.jld2", diff --git a/experiments/one_degree_simulation/one_degree_simulation.jl b/experiments/one_degree_simulation/one_degree_simulation.jl index e48bc4d56..f2deda0a2 100644 --- a/experiments/one_degree_simulation/one_degree_simulation.jl +++ b/experiments/one_degree_simulation/one_degree_simulation.jl @@ -46,7 +46,7 @@ set!(ocean.model, T=ECCOMetadata(:temperature; dates=first(dates)), S=ECCOMetadata(:salinity; dates=first(dates))) radiation = Radiation(arch) -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(41)) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=41) coupled_model = OceanOnlyModel(ocean; atmosphere, radiation) simulation = Simulation(coupled_model; Δt=10minutes, stop_iteration=100) diff --git a/test/runtests.jl b/test/runtests.jl index 39e6985b0..341fcbdc0 
100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,7 +69,7 @@ function __init__() ##### try - atmosphere = JRA55PrescribedAtmosphere(backend=JRA55NetCDFBackend(2)) + atmosphere = JRA55PrescribedAtmosphere(time_indices_in_memory=2) catch e @warn "Original JRA55 download failed, trying NumericalEarthArtifacts fallback..." exception=(e, catch_backtrace()) emit_ci_warning("Broken JRA55 download", "Original source failed during init") @@ -77,7 +77,7 @@ function __init__() datum = Metadatum(name; dataset=JRA55.RepeatYearJRA55()) download_from_artifacts(metadata_path(datum)) end - atmosphere = JRA55PrescribedAtmosphere(backend=JRA55NetCDFBackend(2)) + atmosphere = JRA55PrescribedAtmosphere(time_indices_in_memory=2) end ##### diff --git a/test/test_checkpointer.jl b/test/test_checkpointer.jl index 71ee4ada8..236800ae7 100644 --- a/test/test_checkpointer.jl +++ b/test/test_checkpointer.jl @@ -27,8 +27,7 @@ using Oceananigans.OutputWriters: Checkpointer set!(sea_ice.model, h=hi, ℵ=hi) # Create atmosphere and radiation - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) return OceanSeaIceModel(ocean, sea_ice; atmosphere) end diff --git a/test/test_ocean_only_model.jl b/test/test_ocean_only_model.jl index 8230e813d..23824435f 100644 --- a/test/test_ocean_only_model.jl +++ b/test/test_ocean_only_model.jl @@ -17,8 +17,7 @@ using Oceananigans.OrthogonalSphericalShellGrids data = Int[] pushdata(sim) = push!(data, iteration(sim)) add_callback!(ocean, pushdata) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) coupled_model = OceanOnlyModel(ocean; atmosphere, radiation) Δt = 60 @@ -48,8 +47,7 @@ using Oceananigans.OrthogonalSphericalShellGrids free_surface = SplitExplicitFreeSurface(grid; substeps=20) ocean = 
ocean_simulation(grid; free_surface) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) # Fluxes are computed when the model is constructed, so we just test that this works. diff --git a/test/test_ocean_sea_ice_model.jl b/test/test_ocean_sea_ice_model.jl index ad7a5112b..8df31e915 100644 --- a/test/test_ocean_sea_ice_model.jl +++ b/test/test_ocean_sea_ice_model.jl @@ -57,8 +57,7 @@ using ClimaSeaIce.Rheologies Tm = KernelFunctionOperation{Center, Center, Center}(kernel_melting_temperature, grid, liquidus, S) @test all(T .≥ Tm) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) # Fluxes are computed when the model is constructed, so we just test that this works. diff --git a/test/test_sea_ice_ocean_heat_fluxes.jl b/test/test_sea_ice_ocean_heat_fluxes.jl index 7506d3d78..df7ce565a 100644 --- a/test/test_sea_ice_ocean_heat_fluxes.jl +++ b/test/test_sea_ice_ocean_heat_fluxes.jl @@ -199,8 +199,7 @@ end ocean = ocean_simulation(grid, momentum_advection=nothing, closure=nothing, tracer_advection=nothing) sea_ice = sea_ice_simulation(grid, ocean) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) for sea_ice_ocean_heat_flux in [IceBathHeatFlux(), ThreeEquationHeatFlux()] @@ -257,8 +256,7 @@ end ocean = ocean_simulation(grid, momentum_advection=nothing, closure=nothing, tracer_advection=nothing) sea_ice = sea_ice_simulation(grid, ocean) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) for sea_ice_ocean_heat_flux in 
[IceBathHeatFlux(), ThreeEquationHeatFlux()] @@ -401,8 +399,7 @@ end ocean = ocean_simulation(grid, momentum_advection=nothing, closure=nothing, tracer_advection=nothing) sea_ice = sea_ice_simulation(grid, ocean) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) for sea_ice_ocean_heat_flux in [IceBathHeatFlux(), ThreeEquationHeatFlux()] @@ -452,8 +449,7 @@ end ocean = ocean_simulation(grid, momentum_advection=nothing, closure=nothing, tracer_advection=nothing) sea_ice = sea_ice_simulation(grid, ocean) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) # Test with ThreeEquationHeatFlux (default) From bf25439c57fa02b08d417a178484b5abb685916d Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 15:24:27 +0200 Subject: [PATCH 06/44] more test fixes --- src/DataWrangling/JRA55/JRA55_metadata.jl | 2 +- test/test_surface_fluxes.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index f4cf58d26..f10310b89 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -108,7 +108,7 @@ end # Convenience functions dataset_variable_name(data::JRA55Metadata) = JRA55_dataset_variable_names[data.name] -location(::JRA55Metadata) = (Center, Center, Center) +location(::JRA55Metadata) = (Center, Center, Nothing) available_variables(::JRA55Dataset) = JRA55_variable_names diff --git a/test/test_surface_fluxes.jl b/test/test_surface_fluxes.jl index a993f6dae..66eafd021 100644 --- a/test/test_surface_fluxes.jl +++ b/test/test_surface_fluxes.jl @@ -46,7 +46,7 @@ end bottom_drag_coefficient = 0) dates = all_dates(RepeatYearJRA55(), :temperature) - 
atmosphere = JRA55PrescribedAtmosphere(arch, Float64; end_date=dates[2], backend = InMemory()) + atmosphere = JRA55PrescribedAtmosphere(arch; end_date=dates[2]) CUDA.@allowscalar begin h = atmosphere.surface_layer_height From 6b2830d9eab4fef34ce62ee41629026377a6c065 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 16:44:13 +0200 Subject: [PATCH 07/44] just pass the native grid --- src/DataWrangling/JRA55/JRA55_metadata.jl | 25 +++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index f10310b89..2dd20d223 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -22,6 +22,7 @@ import NumericalEarth.DataWrangling: all_dates, default_inpainting, getfilename, z_interfaces, + native_grid, longitude_interfaces, latitude_interfaces @@ -40,10 +41,30 @@ default_download_directory(::JRA55Dataset) = download_JRA55_cache Base.size(::JRA55Dataset, variable) = (640, 320, 1) -z_interfaces(::JRA55Metadata) = (0, 10) longitude_interfaces(::JRA55Metadata) = (0, 360) latitude_interfaces(::JRA55Metadata) = (-90, 90) +function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3, 3)) + Nx, Ny, Nz, _ = size(metadata) + + FT = eltype(metadata) + + longitude = longitude_interfaces(metadata) + latitude = latitude_interfaces(metadata) + + bbox = metadata.bounding_box + if !isnothing(bbox) + longitude, Nx = restrict(bbox.longitude, longitude, Nx) + latitude, Ny = restrict(bbox.latitude, latitude, Ny) + end + + grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), + halo, longitude, latitude, + topology = (Periodic, Bounded, Flat)) + + return grid +end + # JRA55 is a spatially 2D dataset is_three_dimensional(data::JRA55Metadata) = false @@ -108,7 +129,7 @@ end # Convenience functions dataset_variable_name(data::JRA55Metadata) = JRA55_dataset_variable_names[data.name] 
-location(::JRA55Metadata) = (Center, Center, Nothing) +location(::JRA55Metadata) = (Center, Center, Center) available_variables(::JRA55Dataset) = JRA55_variable_names From 009327d1980e7e17ca44b95cb6b2b5b935165800 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 16:57:23 +0200 Subject: [PATCH 08/44] make sure we use 2 sizes and 2 halos --- src/DataWrangling/JRA55/JRA55_metadata.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index 2dd20d223..f7b48e391 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -44,7 +44,7 @@ Base.size(::JRA55Dataset, variable) = (640, 320, 1) longitude_interfaces(::JRA55Metadata) = (0, 360) latitude_interfaces(::JRA55Metadata) = (-90, 90) -function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3, 3)) +function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3)) Nx, Ny, Nz, _ = size(metadata) FT = eltype(metadata) @@ -58,7 +58,7 @@ function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3, 3)) latitude, Ny = restrict(bbox.latitude, latitude, Ny) end - grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), + grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny), halo, longitude, latitude, topology = (Periodic, Bounded, Flat)) @@ -102,11 +102,8 @@ end # File name generation specific to each Dataset dataset # Note that `RepeatYearJRA55` has only one file associated, so the filename # is independent of the date. Override the multi-date fallback to return a plain String. -metadata_filename(::RepeatYearJRA55, name, date, bounding_box) = - "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" - -build_filename(::RepeatYearJRA55, name, dates::AbstractArray, bounding_box) = - "RYF." 
* JRA55_dataset_variable_names[name] * ".1990_1991.nc" +metadata_filename(::RepeatYearJRA55, name, date, bounding_box) = "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" +build_filename(::RepeatYearJRA55, name, dates::AbstractArray, bounding_box) = "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" function metadata_filename(::MultiYearJRA55, name, date, bounding_box) shortname = JRA55_dataset_variable_names[name] From cd97aa9e545b4043c948193ad70e0d0fa7b779a3 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 16:59:21 +0200 Subject: [PATCH 09/44] inner backend --- src/DataWrangling/prefetching_backend.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl index 717170cb2..8bc10a26b 100644 --- a/src/DataWrangling/prefetching_backend.jl +++ b/src/DataWrangling/prefetching_backend.jl @@ -17,43 +17,43 @@ import Oceananigans.OutputReaders: new_backend import Oceananigans.Fields: set! mutable struct PrefetchingBackend{B<:DatasetBackend, F<:FieldTimeSeries} <: AbstractInMemoryBackend{Int} - inner :: B + inner_backend :: B pending :: Union{Task, Nothing} buffer_fts :: F next_start :: Int end -PrefetchingBackend(inner::DatasetBackend, buffer_fts::FieldTimeSeries) = PrefetchingBackend{typeof(inner), typeof(buffer_fts)}(inner, nothing, buffer_fts, 0) +PrefetchingBackend(inner_backend::DatasetBackend, buffer_fts::FieldTimeSeries) = PrefetchingBackend{typeof(inner_backend), typeof(buffer_fts)}(inner_backend, nothing, buffer_fts, 0) # `:buffer_fts` deliberately warned upon — see race invariant in preamble. function Base.getproperty(p::PrefetchingBackend, name::Symbol) - if name in (:inner, :pending, :next_start) + if name in (:inner_backend, :pending, :next_start) return getfield(p, name) elseif name == :buffer_fts @warn "`buffer_fts` is an inner auxiliary field touched on an hot-loop separate task. 
" * "Mutating it manually might lead to undefined behavior. It is recommended not modifying it." return getfield(p, name) else - return getproperty(getfield(p, :inner), name) + return getproperty(getfield(p, :inner_backend), name) end end -Base.length(p::PrefetchingBackend) = length(p.inner) -Base.summary(p::PrefetchingBackend) = string("PrefetchingBackend(", p.inner.start, ", ", p.inner.length, "; pending=", !isnothing(getfield(p, :pending)), ")") +Base.length(p::PrefetchingBackend) = length(p.inner_backend) +Base.summary(p::PrefetchingBackend) = string("PrefetchingBackend(", p.inner_backend.start, ", ", p.inner_backend.length, "; pending=", !isnothing(getfield(p, :pending)), ")") # Mutate in place rather than constructing a fresh wrapper — keeps the # `pending`/`buffer_fts`/`next_start` mutable state in exactly one object. function new_backend(p::PrefetchingBackend, start, length) - setfield!(p, :inner, new_backend(getfield(p, :inner), start, length)) + setfield!(p, :inner_backend, new_backend(getfield(p, :inner_backend), start, length)) return p end -Adapt.adapt_structure(to, p::PrefetchingBackend) = Adapt.adapt(to, getfield(p, :inner)) +Adapt.adapt_structure(to, p::PrefetchingBackend) = Adapt.adapt(to, getfield(p, :inner_backend)) const PrefetchingFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:PrefetchingBackend} function set!(fts::PrefetchingFTS, backend::PrefetchingBackend = fts.backend) - needed_start = getfield(backend, :inner).start + needed_start = getfield(backend, :inner_backend).start pending = getfield(backend, :pending) pending_start = getfield(backend, :next_start) buffer_fts = getfield(backend, :buffer_fts) @@ -65,7 +65,7 @@ function set!(fts::PrefetchingFTS, backend::PrefetchingBackend = fts.backend) wait(pending) else !isnothing(pending) && wait(pending) - Nm = length(getfield(backend, :inner)) + Nm = length(getfield(backend, :inner_backend)) buffer_fts.backend = new_backend(buffer_fts.backend, needed_start, Nm) set!(buffer_fts) end @@ -74,12 
+74,12 @@ function set!(fts::PrefetchingFTS, backend::PrefetchingBackend = fts.backend) # Time-indexing-aware next-window prediction: `time_index` wraps # via mod1 for Cyclical and clamps to Nt for Linear/Clamp. - Nm = length(getfield(backend, :inner)) + Nm = length(getfield(backend, :inner_backend)) Nt = length(fts.times) new_next = time_index(buffer_fts.backend, fts.time_indexing, Nt, Nm + 1) + # Linear/Clamp at end-of-data: window can't advance, no prefetch. if new_next == needed_start - # Linear/Clamp at end-of-data: window can't advance, no prefetch. setfield!(backend, :next_start, 0) return nothing end From d9ec6e0d0738e5d561411849a2515d0608a18aca Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 17:46:48 +0200 Subject: [PATCH 10/44] pop flat elements --- src/DataWrangling/JRA55/JRA55_metadata.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index f7b48e391..d7608c262 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -3,6 +3,7 @@ using Dates using Downloads using Oceananigans.DistributedComputations +using Oceananigans.Grids: pop_flat_elements using NumericalEarth.DataWrangling using NumericalEarth.DataWrangling: Metadata, metadata_path, download_progress, AnyDateTime, DatasetBackend @@ -48,7 +49,7 @@ function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3)) Nx, Ny, Nz, _ = size(metadata) FT = eltype(metadata) - + halo = pop_flat_elements(halo, (Periodic, Bounded, Flat)) longitude = longitude_interfaces(metadata) latitude = latitude_interfaces(metadata) From b8786d783c1780045760706b99d65d4da7f5297e Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 18:33:56 +0200 Subject: [PATCH 11/44] put a guard on the time_indices_in_memory --- src/DataWrangling/metadata_field_time_series.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 
deletions(-) diff --git a/src/DataWrangling/metadata_field_time_series.jl b/src/DataWrangling/metadata_field_time_series.jl index 61eb57c6a..5db749e8f 100644 --- a/src/DataWrangling/metadata_field_time_series.jl +++ b/src/DataWrangling/metadata_field_time_series.jl @@ -52,22 +52,27 @@ function FieldTimeSeries(metadata::Metadata, grid::AbstractGrid; # Detect "the user's grid IS the native grid" structurally on_native_grid = grid == native_grid(metadata, architecture(grid)) + times = native_times(metadata) + + # Make sure we do not use more indices than the ones available! + if length(times) < time_indices_in_memory + time_indices_in_memory = length(times) + end inpainting isa Int && (inpainting = NearestNeighborInpainting(inpainting)) - inner = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data) + inner_backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data) - times = native_times(metadata) loc = LX, LY, LZ = location(metadata) boundary_conditions = FieldBoundaryConditions(grid, instantiate.(loc)) if prefetch Threads.nthreads() < 2 && @warn "prefetch=true is a no-op with JULIA_NUM_THREADS=$(Threads.nthreads()); start Julia with ≥ 2 threads." # Buffer FTS is allocated once and reused per reload (see prefetching_backend.jl).
- buffer_inner = new_backend(inner, 1, time_indices_in_memory) + buffer_inner = new_backend(inner_backend, 1, time_indices_in_memory) buffer_fts = FieldTimeSeries{LX, LY, LZ}(grid, times; backend=buffer_inner, time_indexing, boundary_conditions) - backend = PrefetchingBackend(inner, buffer_fts) + backend = PrefetchingBackend(inner_backend, buffer_fts) else - backend = inner + backend = inner_backend end fts = FieldTimeSeries{LX, LY, LZ}(grid, times; backend, time_indexing, boundary_conditions) From a06a65641acf408eea2c0d8b36b9ae5c1845625f Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 18:50:01 +0200 Subject: [PATCH 12/44] these variables have a different size --- src/DataWrangling/JRA55/JRA55_field_time_series.jl | 1 - src/DataWrangling/JRA55/JRA55_metadata.jl | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55_field_time_series.jl b/src/DataWrangling/JRA55/JRA55_field_time_series.jl index db59a5a20..4933fd5ba 100644 --- a/src/DataWrangling/JRA55/JRA55_field_time_series.jl +++ b/src/DataWrangling/JRA55/JRA55_field_time_series.jl @@ -195,7 +195,6 @@ function set!(fts::JRA55NetCDFFTSRepeatYear, backend=fts.backend) ds = Dataset(path) # Nodes at the variable location - λc = ds["lon"][:] φc = ds["lat"][:] LX, LY, LZ = location(fts) diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index d7608c262..4cb2338cc 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -40,7 +40,13 @@ const MultiYearJRA55Metadatum = Metadatum{<:MultiYearJRA55} default_download_directory(::JRA55Dataset) = download_JRA55_cache -Base.size(::JRA55Dataset, variable) = (640, 320, 1) +function Base.size(::JRA55Dataset, variable) + if variable ∈ [:river_freshwater_flux, :iceberg_freshwater_flux] + (1440, 720, 1) + else + (640, 320, 1) + end +end longitude_interfaces(::JRA55Metadata) = (0, 360) 
latitude_interfaces(::JRA55Metadata) = (-90, 90) From e7a6070d10d11e2747b0386f273071755f122229 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 18:57:57 +0200 Subject: [PATCH 13/44] discontinue this example --- docs/make.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index f3f332f8d..7bcf21270 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -29,7 +29,7 @@ mkpath(OUTPUT_DIR) # Set `build_always = false` for long-running examples that should only be built # on pushes to `main`/tags, or when the `build all examples` label is added to a PR. examples = [ - Example("Single-column ocean simulation", "single_column_os_papa_simulation", true), + # Example("Single-column ocean simulation", "single_column_os_papa_simulation", true), Example("One-degree ocean--sea ice simulation", "one_degree_simulation", false), Example("Near-global ocean simulation", "near_global_ocean_simulation", false), Example("Global climate simulation", "global_climate_simulation", false), From 58c73ad8ccad9af3ae3eef568612f2f1ce1b98f7 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 19:30:18 +0200 Subject: [PATCH 14/44] fix surface flux tests --- test/test_surface_fluxes.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_surface_fluxes.jl b/test/test_surface_fluxes.jl index 66eafd021..29f95a5b2 100644 --- a/test/test_surface_fluxes.jl +++ b/test/test_surface_fluxes.jl @@ -32,7 +32,7 @@ end @testset "Test surface fluxes" begin for arch in test_architectures - grid = LatitudeLongitudeGrid(arch; + grid = LatitudeLongitudeGrid(arch, Float32; size = 1, latitude = 10, longitude = 10, @@ -53,7 +53,7 @@ end pᵃᵗ = atmosphere.pressure[1][1, 1, 1] Tᵃᵗ = 15 + celsius_to_kelvin - qᵃᵗ = 0.003 + qᵃᵗ = Float32(0.003) uᵃᵗ = atmosphere.velocities.u[1][1, 1, 1] vᵃᵗ = atmosphere.velocities.v[1][1, 1, 1] @@ -192,7 +192,7 @@ end bottom_drag_coefficient = 0.0) dates = all_dates(RepeatYearJRA55(), 
:temperature) - atmosphere = JRA55PrescribedAtmosphere(arch; end_date=dates[2], backend = InMemory()) + atmosphere = JRA55PrescribedAtmosphere(arch; end_date=dates[2]) fill!(ocean.model.tracers.T, -2.0) From 2efbaade5ef5e01f852c9a102b13c5918300e94e Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 20:25:49 +0200 Subject: [PATCH 15/44] fix tests --- src/DataWrangling/JRA55/JRA55_metadata.jl | 20 +++++++++++++++++--- test/test_jra55.jl | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index 4cb2338cc..c09a04045 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -56,8 +56,8 @@ function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3)) FT = eltype(metadata) halo = pop_flat_elements(halo, (Periodic, Bounded, Flat)) - longitude = longitude_interfaces(metadata) - latitude = latitude_interfaces(metadata) + + longitude, latitude = jra55_native_interfaces(metadata_path(first(metadata))) bbox = metadata.bounding_box if !isnothing(bbox) @@ -66,12 +66,26 @@ function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3)) end grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny), - halo, longitude, latitude, + halo, longitude, latitude, topology = (Periodic, Bounded, Flat)) return grid end +function jra55_native_interfaces(path) + ds = Dataset(path) + λn = Array{Float64}(ds["lon_bnds"][1, :]) + φn = Array{Float64}(ds["lat_bnds"][1, :]) + close(ds) + + # `lon_bnds` / `lat_bnds` hold only the left/lower interface, + # so we need to append the trailing interface. 
+ push!(λn, λn[1] + 360) + push!(φn, 90) + + return λn, φn +end + # JRA55 is a spatially 2D dataset is_three_dimensional(data::JRA55Metadata) = false diff --git a/test/test_jra55.jl b/test/test_jra55.jl index e3fe40e2e..1f1c614db 100644 --- a/test/test_jra55.jl +++ b/test/test_jra55.jl @@ -118,7 +118,7 @@ using NumericalEarth.DataWrangling: compute_native_date_range name = :downwelling_shortwave_radiation dates = NumericalEarth.DataWrangling.all_dates(JRA55.RepeatYearJRA55(), name) end_date = dates[3] - JRA55_fts = FieldTimeSeries(Metadata(name; dataset=JRA55.RepeatYearJRA55(), end_date), arch) + JRA55_fts = FieldTimeSeries(Metadata(name; dataset=JRA55.RepeatYearJRA55(), end_date), arch; time_indices_in_memory=3) # Make target grid and field resolution = 1 # degree, eg 1/4 From 0425328c4cd9f3b3799588f586c525a77825ea6c Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 22:59:49 +0200 Subject: [PATCH 16/44] Update src/DataWrangling/restoring.jl Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/DataWrangling/restoring.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DataWrangling/restoring.jl b/src/DataWrangling/restoring.jl index 1aaa59af8..966c15827 100644 --- a/src/DataWrangling/restoring.jl +++ b/src/DataWrangling/restoring.jl @@ -188,7 +188,7 @@ Keyword Arguments - `cache_inpainted_data`: If `true`, the data is cached to disk after inpainting for later retrieving. Default: `true`. -- `prefetch`: If `true`, hide the next reload's I/O behind compute via a background `Threads.@spawn`. +- `prefetch`: If `true`, hide the next reload's I/O behind compute via a background `Threads.@spawn` task. Intended for long-lived FTSes; short-lived ones leak one prefetch task. Default: `false`. 
""" function DatasetRestoring(metadata::Metadata, From 3b421673fb92851a72dcc30998057933df69959c Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 20 Apr 2026 23:00:02 +0200 Subject: [PATCH 17/44] Update src/DataWrangling/prefetching_backend.jl Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/DataWrangling/prefetching_backend.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl index 8bc10a26b..40edc84e1 100644 --- a/src/DataWrangling/prefetching_backend.jl +++ b/src/DataWrangling/prefetching_backend.jl @@ -30,8 +30,8 @@ function Base.getproperty(p::PrefetchingBackend, name::Symbol) if name in (:inner_backend, :pending, :next_start) return getfield(p, name) elseif name == :buffer_fts - @warn "`buffer_fts` is an inner auxiliary field touched on an hot-loop separate task. " * - "Mutating it manually might lead to undefined behavior. It is recommended not modifying it." + @warn "`buffer_fts` is an inner auxiliary field touched in a hot loop by a separate task. " * + "Mutating it manually might lead to undefined behavior. It is recommended not to modify it." return getfield(p, name) else return getproperty(getfield(p, :inner_backend), name) From 00e4273fdb8d765d036e0f10645f175117c2dab5 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Wed, 22 Apr 2026 17:12:10 +0200 Subject: [PATCH 18/44] Fix prefetch predictor off-by-one (slide by Nm - 1, not Nm) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `PrefetchingBackend` predicted the next window to load as `needed_start + Nm`, assuming the sliding window advances by a whole window every reload. It doesn't: `update_field_time_series!` sets `start = n₁`, and the first `n₂ = n₁ + 1` outside the current window triggers the reload, so the new `start` is `needed_start + Nm - 1`. 
Every reload therefore hit the COLD branch — wait on a pending prefetch whose target was off by one, then do the full synchronous load anyway. On OMIP 1° this produced a ~15 s wall-time spike at every window boundary, strictly worse than `prefetch = false`. Pass `Nm` (not `Nm + 1`) to `time_index` so the predictor matches the actual slide. Updates `test/test_jra55.jl` to drive the PrefetchingFTS with `Nb - 1` slides — the previous sequence `(Nb+1, 2Nb+1, …)` happened to agree with the buggy predictor and so never exposed it. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/DataWrangling/prefetching_backend.jl | 10 +++++++-- test/test_jra55.jl | 28 ++++++++++++++---------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl index 40edc84e1..23b97c117 100644 --- a/src/DataWrangling/prefetching_backend.jl +++ b/src/DataWrangling/prefetching_backend.jl @@ -73,10 +73,16 @@ function set!(fts::PrefetchingFTS, backend::PrefetchingBackend = fts.backend) copyto!(parent(fts.data), parent(buffer_fts.data)) # Time-indexing-aware next-window prediction: `time_index` wraps - # via mod1 for Cyclical and clamps to Nt for Linear/Clamp. + # via mod1 for Cyclical and clamps to Nt for Linear/Clamp. The next + # reload fires the first time `n₂ = n₁ + 1` falls outside the current + # window, which happens when `n₁` hits the LAST in-memory index, so + # the new `start` is `needed_start + Nm - 1`, not `needed_start + Nm` + # (`update_field_time_series!` sets `start = n₁`). Passing `Nm` + # instead of `Nm + 1` yields that prediction and keeps every reload + # on the HOT path. Nm = length(getfield(backend, :inner_backend)) Nt = length(fts.times) - new_next = time_index(buffer_fts.backend, fts.time_indexing, Nt, Nm + 1) + new_next = time_index(buffer_fts.backend, fts.time_indexing, Nt, Nm) # Linear/Clamp at end-of-data: window can't advance, no prefetch. 
if new_next == needed_start diff --git a/test/test_jra55.jl b/test/test_jra55.jl index 1f1c614db..97d63df19 100644 --- a/test/test_jra55.jl +++ b/test/test_jra55.jl @@ -72,23 +72,27 @@ using NumericalEarth.DataWrangling: compute_native_date_range @test pf_fts.backend isa NumericalEarth.DataWrangling.PrefetchingBackend @test parent(pf_fts.data) == parent(ref_fts.data) # cold load alignment - @test pf_fts.backend.next_start == Nb + 1 # next prefetch scheduled - - # Reload sequence: - # * Nb+1, 2Nb+1 → straight hot-path advances - # * Nt-Nb+1 → places the next prefetch's window across - # the end-of-times boundary, exercising the - # `mod1(start+Nm, Nt)` wrap when scheduling - # the prefetch - # * 1 → consumes that wrapped prefetch (hot path) - # at the start of the cycle - for next_start in (Nb + 1, 2Nb + 1, Nt - Nb + 1, 1) + @test pf_fts.backend.next_start == Nb # next prefetch scheduled + + # Reload sequence mirrors what `update_field_time_series!` + # produces at run time — a slide of `Nb - 1` per reload, because + # the first `n₂ = n₁ + 1` outside the window triggers it and the + # new `start` is set to `n₁` (the last in-memory index of the + # previous window): + # * Nb, 2Nb - 1 → hot-path advances; each matches the + # predictor `start + Nm - 1` + # * Nt - Nb + 1 → arbitrary jump (cold-path); schedules a + # prefetch whose window crosses the end of + # times, exercising `mod1(start+Nm-1, Nt)` + # * Nt → consumes that wrapped prefetch (hot + # path) at the very end of the cycle + for next_start in (Nb, 2Nb - 1, Nt - Nb + 1, Nt) ref_fts.backend = Oceananigans.OutputReaders.new_backend(ref_fts.backend, next_start, Nb) pf_fts.backend = Oceananigans.OutputReaders.new_backend(pf_fts.backend, next_start, Nb) set!(ref_fts) set!(pf_fts) @test parent(pf_fts.data) == parent(ref_fts.data) - @test pf_fts.backend.next_start == mod1(next_start + Nb, Nt) + @test pf_fts.backend.next_start == mod1(next_start + Nb - 1, Nt) end end From 92d98e739de315507ce90d8536bdd16bc653e873 Mon 
Sep 17 00:00:00 2001 From: Simone Silvestri Date: Thu, 23 Apr 2026 10:31:37 +0200 Subject: [PATCH 19/44] Recover from failed prefetch instead of killing the simulation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A Threads.@spawn-ed prefetch task that fails (transient I/O error, FS hiccup, staging race, etc.) previously left a Failed `Task` in the backend's `:pending` slot. The next `set!` on that FTS called `wait(pending)`, which rethrew the exception as a `TaskFailedException` and aborted the whole run. Demote a failed hot-path `wait` to the existing synchronous cold path: the main thread does one extra blocking load, logs a @warn with dataset/variable context, and the simulation continues. A stale-but- failed drain in the cold branch is swallowed outright — we're about to reload from scratch anyway, and the spawn site already logged the failure with full backtrace. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/DataWrangling/prefetching_backend.jl | 31 +++++++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl index 23b97c117..5ba70f199 100644 --- a/src/DataWrangling/prefetching_backend.jl +++ b/src/DataWrangling/prefetching_backend.jl @@ -61,10 +61,33 @@ function set!(fts::PrefetchingFTS, backend::PrefetchingBackend = fts.backend) # Cleared up-front so a failed prefetch isn't re-thrown on every later set!. setfield!(backend, :pending, nothing) - if !isnothing(pending) && pending_start == needed_start - wait(pending) - else - !isnothing(pending) && wait(pending) + # Hot path: the pending prefetch already targets `needed_start`. Wait on + # it (typically a no-op — the background load finished while compute was + # running). A failed prefetch demotes to the synchronous load below, so a + # transient I/O error (a brief FS hiccup, a staging race, etc.) doesn't + # kill the simulation. 
The spawn site has already logged the exception + # with full variable/window context, so we just print a short warning + # here. + hot = !isnothing(pending) && pending_start == needed_start + if hot + try + wait(pending) + catch + m = buffer_fts.backend.metadata + @warn "PrefetchingBackend: pending prefetch failed; falling back to synchronous load" dataset=typeof(m.dataset) variable=m.name + hot = false + end + elseif !isnothing(pending) + # Stale prefetch targets a different window; drain it and swallow + # any failure — we're about to reload from scratch anyway and the + # spawn site has already logged the failure if there was one. + try + wait(pending) + catch + end + end + + if !hot Nm = length(getfield(backend, :inner_backend)) buffer_fts.backend = new_backend(buffer_fts.backend, needed_start, Nm) set!(buffer_fts) From ea52120080dc7da58378b47dacf3e9942017a459 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Thu, 23 Apr 2026 12:06:40 +0200 Subject: [PATCH 20/44] remove dataset prefetch --- src/DataWrangling/DataWrangling.jl | 1 - src/DataWrangling/ECCO/ECCO_atmosphere.jl | 3 +- .../JRA55/JRA55_field_time_series.jl | 3 +- .../JRA55/JRA55_prescribed_atmosphere.jl | 9 +- .../metadata_field_time_series.jl | 19 +-- src/DataWrangling/prefetching_backend.jl | 130 ------------------ src/DataWrangling/restoring.jl | 9 +- test/test_checkpointer.jl | 58 -------- test/test_jra55.jl | 38 ----- 9 files changed, 10 insertions(+), 260 deletions(-) delete mode 100644 src/DataWrangling/prefetching_backend.jl diff --git a/src/DataWrangling/DataWrangling.jl b/src/DataWrangling/DataWrangling.jl index 9591bf3bc..38be1665c 100644 --- a/src/DataWrangling/DataWrangling.jl +++ b/src/DataWrangling/DataWrangling.jl @@ -199,7 +199,6 @@ default_mask_value(dataset) = NaN include("metadata.jl") include("metadata_field.jl") include("dataset_backend.jl") -include("prefetching_backend.jl") include("metadata_field_time_series.jl") include("inpainting.jl") include("restoring.jl") diff --git 
a/src/DataWrangling/ECCO/ECCO_atmosphere.jl b/src/DataWrangling/ECCO/ECCO_atmosphere.jl index 7858dbfcb..56813ca4e 100644 --- a/src/DataWrangling/ECCO/ECCO_atmosphere.jl +++ b/src/DataWrangling/ECCO/ECCO_atmosphere.jl @@ -27,12 +27,11 @@ function ECCOPrescribedAtmosphere(architecture = CPU(), FT = Float32; end_date = last_date(dataset, :air_temperature), dir = default_download_directory(dataset), time_indexing = Cyclical(), - prefetch = false, time_indices_in_memory = 10, surface_layer_height = 2, # meters other_kw...) - kw = (; time_indexing, time_indices_in_memory, prefetch) + kw = (; time_indexing, time_indices_in_memory) kw = merge(kw, other_kw) ecco_fts(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir), architecture; kw...) diff --git a/src/DataWrangling/JRA55/JRA55_field_time_series.jl b/src/DataWrangling/JRA55/JRA55_field_time_series.jl index 4933fd5ba..f104348b2 100644 --- a/src/DataWrangling/JRA55/JRA55_field_time_series.jl +++ b/src/DataWrangling/JRA55/JRA55_field_time_series.jl @@ -84,8 +84,7 @@ function JRA55FieldTimeSeries(args...; kwargs...) FieldTimeSeries(Metadata(:variable_name; dataset=RepeatYearJRA55(), start_date, end_date, dir), architecture; - time_indices_in_memory = N, - prefetch = false) + time_indices_in_memory = N) The `InMemory()` backend is no longer supported for JRA55; pass `time_indices_in_memory = length(metadata)` to keep the whole diff --git a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl index a5222411a..80ec7a814 100644 --- a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl +++ b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl @@ -11,7 +11,6 @@ JRA55PrescribedAtmosphere(arch::Distributed; kw...) = dir = download_JRA55_cache, time_indices_in_memory = 10, time_indexing = Cyclical(), - prefetch = false, surface_layer_height = 10, # meters include_rivers_and_icebergs = false, other_kw...) 
@@ -19,10 +18,7 @@ JRA55PrescribedAtmosphere(arch::Distributed; kw...) = Return a [`PrescribedAtmosphere`](@ref) representing JRA55 reanalysis data. Each atmospheric field is constructed via `FieldTimeSeries(::JRA55Metadata)`, which uses a `DatasetBackend` parameterised by JRA55 metadata so that the -JRA55-specific `set!` (chunked-yearly NetCDF) is dispatched. With -`prefetch = true` each variable's next sliding window is loaded -asynchronously on a background thread so the reload spike (~15 s on a -240-step window across 9 variables) is hidden behind compute. +JRA55-specific `set!` (chunked-yearly NetCDF) is dispatched. """ function JRA55PrescribedAtmosphere(architecture = CPU(); dataset = RepeatYearJRA55(), @@ -31,12 +27,11 @@ function JRA55PrescribedAtmosphere(architecture = CPU(); dir = download_JRA55_cache, time_indices_in_memory = 10, time_indexing = Cyclical(), - prefetch = false, surface_layer_height = 10, # meters include_rivers_and_icebergs = false, other_kw...) - kw = (; time_indexing, time_indices_in_memory, prefetch) + kw = (; time_indexing, time_indices_in_memory) kw = merge(kw, other_kw) jra55_fts(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir), architecture; kw...) 
diff --git a/src/DataWrangling/metadata_field_time_series.jl b/src/DataWrangling/metadata_field_time_series.jl
index 5db749e8f..5228cb00f 100644
--- a/src/DataWrangling/metadata_field_time_series.jl
+++ b/src/DataWrangling/metadata_field_time_series.jl
@@ -45,36 +45,25 @@ function FieldTimeSeries(metadata::Metadata, grid::AbstractGrid;
                          time_indices_in_memory = 2,
                          time_indexing = Cyclical(),
                          inpainting = default_inpainting(metadata),
-                         cache_inpainted_data = true,
-                         prefetch = false)
+                         cache_inpainted_data = true)
 
     download_dataset(metadata)
 
-    # Detect "the user's grid IS the native grid" structurally 
+    # Detect "the user's grid IS the native grid" structurally
     on_native_grid = grid == native_grid(metadata, architecture(grid))
 
     times = native_times(metadata)
-    
+
     # Make sure we do not use more indices than the ones available!
     if length(times) < time_indices_in_memory
         time_indices_in_memory = length(times)
     end
 
     inpainting isa Int && (inpainting = NearestNeighborInpainting(inpainting))
-    inner_backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data)
+    backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data)
 
     loc = LX, LY, LZ = location(metadata)
     boundary_conditions = FieldBoundaryConditions(grid, instantiate.(loc))
 
-    if prefetch
-        Threads.nthreads() < 2 && @warn "prefetch=true is a no-op with JULIA_NUM_THREADS=$(Threads.nthreads()); start Julia with ≥ 2 threads."
-        # Buffer FTS is allocated once and reused per reload (see prefetching_backend.jl).
- buffer_inner = new_backend(inner_backend, 1, time_indices_in_memory) - buffer_fts = FieldTimeSeries{LX, LY, LZ}(grid, times; backend=buffer_inner, time_indexing, boundary_conditions) - backend = PrefetchingBackend(inner_backend, buffer_fts) - else - backend = inner_backend - end - fts = FieldTimeSeries{LX, LY, LZ}(grid, times; backend, time_indexing, boundary_conditions) set!(fts) diff --git a/src/DataWrangling/prefetching_backend.jl b/src/DataWrangling/prefetching_backend.jl deleted file mode 100644 index 5ba70f199..000000000 --- a/src/DataWrangling/prefetching_backend.jl +++ /dev/null @@ -1,130 +0,0 @@ -# Asynchronous prefetch wrapper around `DatasetBackend`. Hides the next sliding-window's I/O behind the current window's compute -# by reading into a buffer `FieldTimeSeries` on a `Threads.@spawn`-ed task. Every `set!` either copies from the prefetched buffer -# (hot) or loads synchronously (cold), then schedules the next window's read. The buffer is allocated once at FTS construction and -# reused for every reload — zero allocation per `set!`. -# -# Race invariant: between the spawn at the end of one `set!` and the `wait` at the start of the next, the worker is mutating -# `buffer_fts.data`. No code outside `set!(::PrefetchingFTS)` may touch `buffer_fts` in that window. -# Two enforcement points: `:buffer_fts` is not forwarded by `getproperty`, and `Adapt.adapt_structure` returns -# only the inner backend. -# -# Requires `JULIA_NUM_THREADS ≥ 2` to actually overlap; one thread makes the spawn cooperatively-scheduled and the optimisation a no-op. - -using Oceananigans.OutputReaders: AbstractInMemoryBackend, FlavorOfFTS, FieldTimeSeries, time_index -using Oceananigans.Fields: location - -import Oceananigans.OutputReaders: new_backend -import Oceananigans.Fields: set! 
- -mutable struct PrefetchingBackend{B<:DatasetBackend, F<:FieldTimeSeries} <: AbstractInMemoryBackend{Int} - inner_backend :: B - pending :: Union{Task, Nothing} - buffer_fts :: F - next_start :: Int -end - -PrefetchingBackend(inner_backend::DatasetBackend, buffer_fts::FieldTimeSeries) = PrefetchingBackend{typeof(inner_backend), typeof(buffer_fts)}(inner_backend, nothing, buffer_fts, 0) - -# `:buffer_fts` deliberately warned upon — see race invariant in preamble. -function Base.getproperty(p::PrefetchingBackend, name::Symbol) - if name in (:inner_backend, :pending, :next_start) - return getfield(p, name) - elseif name == :buffer_fts - @warn "`buffer_fts` is an inner auxiliary field touched in a hot loop by a separate task. " * - "Mutating it manually might lead to undefined behavior. It is recommended not to modify it." - return getfield(p, name) - else - return getproperty(getfield(p, :inner_backend), name) - end -end - -Base.length(p::PrefetchingBackend) = length(p.inner_backend) -Base.summary(p::PrefetchingBackend) = string("PrefetchingBackend(", p.inner_backend.start, ", ", p.inner_backend.length, "; pending=", !isnothing(getfield(p, :pending)), ")") - -# Mutate in place rather than constructing a fresh wrapper — keeps the -# `pending`/`buffer_fts`/`next_start` mutable state in exactly one object. 
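The `getproperty` forwarding above is a general pattern: the wrapper answers for its own fields by name and delegates everything else to the wrapped object. A minimal standalone sketch of that pattern, with purely illustrative names that are not part of the package:

```julia
# Sketch of the getproperty-forwarding pattern used by PrefetchingBackend:
# wrapper-owned state is served via getfield, every other property is
# delegated to the inner object. `Inner`/`Wrapper` are illustrative names.
struct Inner
    start  :: Int
    length :: Int
end

mutable struct Wrapper
    inner   :: Inner
    pending :: Bool
end

function Base.getproperty(w::Wrapper, name::Symbol)
    if name in (:inner, :pending)
        return getfield(w, name)                       # wrapper-owned state
    else
        return getproperty(getfield(w, :inner), name)  # forwarded to the inner object
    end
end

w = Wrapper(Inner(3, 10), false)
w.start    # → 3, forwarded from the inner object
w.pending  # → false, wrapper-owned
```

This is why `fts.backend.start` keeps addressing the inner `DatasetBackend` after wrapping, while mutable prefetch state stays in exactly one object.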
-function new_backend(p::PrefetchingBackend, start, length) - setfield!(p, :inner_backend, new_backend(getfield(p, :inner_backend), start, length)) - return p -end - -Adapt.adapt_structure(to, p::PrefetchingBackend) = Adapt.adapt(to, getfield(p, :inner_backend)) - -const PrefetchingFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:PrefetchingBackend} - -function set!(fts::PrefetchingFTS, backend::PrefetchingBackend = fts.backend) - needed_start = getfield(backend, :inner_backend).start - pending = getfield(backend, :pending) - pending_start = getfield(backend, :next_start) - buffer_fts = getfield(backend, :buffer_fts) - - # Cleared up-front so a failed prefetch isn't re-thrown on every later set!. - setfield!(backend, :pending, nothing) - - # Hot path: the pending prefetch already targets `needed_start`. Wait on - # it (typically a no-op — the background load finished while compute was - # running). A failed prefetch demotes to the synchronous load below, so a - # transient I/O error (a brief FS hiccup, a staging race, etc.) doesn't - # kill the simulation. The spawn site has already logged the exception - # with full variable/window context, so we just print a short warning - # here. - hot = !isnothing(pending) && pending_start == needed_start - if hot - try - wait(pending) - catch - m = buffer_fts.backend.metadata - @warn "PrefetchingBackend: pending prefetch failed; falling back to synchronous load" dataset=typeof(m.dataset) variable=m.name - hot = false - end - elseif !isnothing(pending) - # Stale prefetch targets a different window; drain it and swallow - # any failure — we're about to reload from scratch anyway and the - # spawn site has already logged the failure if there was one. 
- try - wait(pending) - catch - end - end - - if !hot - Nm = length(getfield(backend, :inner_backend)) - buffer_fts.backend = new_backend(buffer_fts.backend, needed_start, Nm) - set!(buffer_fts) - end - - copyto!(parent(fts.data), parent(buffer_fts.data)) - - # Time-indexing-aware next-window prediction: `time_index` wraps - # via mod1 for Cyclical and clamps to Nt for Linear/Clamp. The next - # reload fires the first time `n₂ = n₁ + 1` falls outside the current - # window, which happens when `n₁` hits the LAST in-memory index, so - # the new `start` is `needed_start + Nm - 1`, not `needed_start + Nm` - # (`update_field_time_series!` sets `start = n₁`). Passing `Nm` - # instead of `Nm + 1` yields that prediction and keeps every reload - # on the HOT path. - Nm = length(getfield(backend, :inner_backend)) - Nt = length(fts.times) - new_next = time_index(buffer_fts.backend, fts.time_indexing, Nt, Nm) - - # Linear/Clamp at end-of-data: window can't advance, no prefetch. - if new_next == needed_start - setfield!(backend, :next_start, 0) - return nothing - end - - buffer_fts.backend = new_backend(buffer_fts.backend, new_next, Nm) - setfield!(backend, :next_start, new_next) - # Worker-side @error logs context (spawn site is gone by `wait` time); rethrow preserves the original. - setfield!(backend, :pending, Threads.@spawn begin - try - set!(buffer_fts) - catch e - m = buffer_fts.backend.metadata - @error "PrefetchingBackend: prefetch task failed" dataset=typeof(m.dataset) variable=m.name window=(new_next, new_next + Nm - 1) exception=(e, catch_backtrace()) - rethrow() - end - end) - - return nothing -end diff --git a/src/DataWrangling/restoring.jl b/src/DataWrangling/restoring.jl index 0d786221d..69e3cf18f 100644 --- a/src/DataWrangling/restoring.jl +++ b/src/DataWrangling/restoring.jl @@ -189,9 +189,6 @@ Keyword Arguments - `cache_inpainted_data`: If `true`, the data is cached to disk after inpainting for later retrieving. Default: `true`. 
- -- `prefetch`: If `true`, hide the next reload's I/O behind compute via a background `Threads.@spawn` task. - Intended for long-lived FTSes; short-lived ones leak one prefetch task. Default: `false`. """ function DatasetRestoring(metadata::Metadata, arch_or_grid = CPU(); @@ -200,8 +197,7 @@ function DatasetRestoring(metadata::Metadata, time_indices_in_memory = default_time_indices_in_memory(metadata), time_indexing = Cyclical(), inpainting = NearestNeighborInpainting(Inf), - cache_inpainted_data = true, - prefetch = false) + cache_inpainted_data = true) download_dataset(metadata) @@ -209,8 +205,7 @@ function DatasetRestoring(metadata::Metadata, time_indices_in_memory, time_indexing, inpainting, - cache_inpainted_data, - prefetch) + cache_inpainted_data) arch = architecture(fts) mask = on_architecture(arch, mask) diff --git a/test/test_checkpointer.jl b/test/test_checkpointer.jl index 236800ae7..031f83711 100644 --- a/test/test_checkpointer.jl +++ b/test/test_checkpointer.jl @@ -109,63 +109,5 @@ using Oceananigans.OutputWriters: Checkpointer # Cleanup rm.(glob("$(prefix)_iteration*.jld2"), force=true) - - # ---- Same workflow but with prefetch=true on the JRA55 atmosphere. - # Verifies that (1) the JLD2 checkpointer doesn't choke on the - # PrefetchingBackend's mutable state and (2) the restored model - # produces the same data as the reference (i.e. the cold-path - # fallback on restart correctly re-prefetches). With nthreads()=1 - # the prefetch becomes a no-op, which still exercises the - # serialisation and cold-fallback logic. 
- @info "Testing EarthSystemModel checkpointing with prefetch=true on $A" - - function make_coupled_model_prefetch(grid) - @inline hi(λ, φ) = φ > 70 || φ < -70 - - ocean = ocean_simulation(grid, closure=nothing) - set!(ocean.model, T=20, S=35, u=0.01, v=-0.005) - sea_ice = sea_ice_simulation(grid, ocean) - set!(sea_ice.model, h=hi, ℵ=hi) - - atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4, prefetch=true) - - return OceanSeaIceModel(ocean, sea_ice; atmosphere) - end - - # Reference run with prefetch - model = make_coupled_model_prefetch(grid) - run!(Simulation(model, Δt=60, stop_iteration=3)) - run!(Simulation(model, Δt=60, stop_iteration=6)) - - ref_T = Array(interior(model.ocean.model.tracers.T)) - ref_h = Array(interior(model.sea_ice.model.ice_thickness)) - ref_time = model.clock.time - - # Checkpointed run with prefetch - model = make_coupled_model_prefetch(grid) - simulation = Simulation(model, Δt=60, stop_iteration=3) - prefix_pf = "osm_checkpointer_prefetch_test_$(typeof(arch))" - simulation.output_writers[:checkpointer] = Checkpointer(simulation.model; - schedule = IterationInterval(3), - prefix = prefix_pf) - run!(simulation) - @test isfile("$(prefix_pf)_iteration3.jld2") # JLD2 didn't choke on prefetch state - - model = make_coupled_model_prefetch(grid) - simulation = Simulation(model, Δt=60, stop_iteration=6) - simulation.output_writers[:checkpointer] = Checkpointer(model; - schedule = IterationInterval(3), - prefix = prefix_pf) - set!(simulation; checkpoint=:latest) - set!(simulation; iteration=3) - run!(simulation) - - T = Array(interior(model.ocean.model.tracers.T)) - h = Array(interior(model.sea_ice.model.ice_thickness)) - @test T ≈ ref_T rtol=1e-13 - @test h ≈ ref_h rtol=1e-13 - @test model.clock.time == ref_time - - rm.(glob("$(prefix_pf)_iteration*.jld2"), force=true) end end diff --git a/test/test_jra55.jl b/test/test_jra55.jl index 97d63df19..9384f2c43 100644 --- a/test/test_jra55.jl +++ b/test/test_jra55.jl @@ -56,44 +56,6 
@@ using NumericalEarth.DataWrangling: compute_native_date_range f₁′ = view(parent(netcdf_JRA55_fts), :, :, 1, 4) f₁′ = Array(f₁′) @test f₁ == f₁′ - - @info "Testing PrefetchingBackend on $A for $test_name..." - # Build a reference (cold) FTS and a prefetching FTS over the same - # window, then drive each through several reloads. After every - # reload the parent data of the prefetching FTS must be byte- - # identical to the reference. The first reload exercises the cold - # fallback (no prior prefetch); subsequent reloads exercise the - # hot path; the wrap from `Nt-3..Nt` back to `1..Nb` exercises the - # cyclical prefetch logic (`mod1(start+Nm, Nt)`). - ref_fts = FieldTimeSeries(Metadata(test_name; dataset=JRA55.RepeatYearJRA55()), arch; - time_indices_in_memory=Nb) - pf_fts = FieldTimeSeries(Metadata(test_name; dataset=JRA55.RepeatYearJRA55()), arch; - time_indices_in_memory=Nb, prefetch=true) - - @test pf_fts.backend isa NumericalEarth.DataWrangling.PrefetchingBackend - @test parent(pf_fts.data) == parent(ref_fts.data) # cold load alignment - @test pf_fts.backend.next_start == Nb # next prefetch scheduled - - # Reload sequence mirrors what `update_field_time_series!` - # produces at run time — a slide of `Nb - 1` per reload, because - # the first `n₂ = n₁ + 1` outside the window triggers it and the - # new `start` is set to `n₁` (the last in-memory index of the - # previous window): - # * Nb, 2Nb - 1 → hot-path advances; each matches the - # predictor `start + Nm - 1` - # * Nt - Nb + 1 → arbitrary jump (cold-path); schedules a - # prefetch whose window crosses the end of - # times, exercising `mod1(start+Nm-1, Nt)` - # * Nt → consumes that wrapped prefetch (hot - # path) at the very end of the cycle - for next_start in (Nb, 2Nb - 1, Nt - Nb + 1, Nt) - ref_fts.backend = Oceananigans.OutputReaders.new_backend(ref_fts.backend, next_start, Nb) - pf_fts.backend = Oceananigans.OutputReaders.new_backend(pf_fts.backend, next_start, Nb) - set!(ref_fts) - set!(pf_fts) - 
@test parent(pf_fts.data) == parent(ref_fts.data) - @test pf_fts.backend.next_start == mod1(next_start + Nb - 1, Nt) - end end @info "Testing Field(::JRA55Metadatum) on $A..." From ab8aa12ad41789ca5308ded6d12eb691da90fe6a Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Thu, 23 Apr 2026 12:18:43 +0200 Subject: [PATCH 21/44] remove JRA55NetCDFBackend --- src/DataWrangling/JRA55/JRA55.jl | 2 +- .../JRA55/JRA55_field_time_series.jl | 15 --------------- src/NumericalEarth.jl | 2 -- 3 files changed, 1 insertion(+), 18 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55.jl b/src/DataWrangling/JRA55/JRA55.jl index 1428580b4..2ee0de14b 100644 --- a/src/DataWrangling/JRA55/JRA55.jl +++ b/src/DataWrangling/JRA55/JRA55.jl @@ -1,6 +1,6 @@ module JRA55 -export JRA55PrescribedAtmosphere, RepeatYearJRA55, MultiYearJRA55, JRA55NetCDFBackend, JRA55FieldTimeSeries +export JRA55PrescribedAtmosphere, RepeatYearJRA55, MultiYearJRA55, JRA55FieldTimeSeries using Oceananigans using Oceananigans.Units diff --git a/src/DataWrangling/JRA55/JRA55_field_time_series.jl b/src/DataWrangling/JRA55/JRA55_field_time_series.jl index f104348b2..621bcddbd 100644 --- a/src/DataWrangling/JRA55/JRA55_field_time_series.jl +++ b/src/DataWrangling/JRA55/JRA55_field_time_series.jl @@ -92,21 +92,6 @@ function JRA55FieldTimeSeries(args...; kwargs...) """) end -""" - JRA55NetCDFBackend(length [, metadata]) - JRA55NetCDFBackend(start, length, metadata) - -Backwards-compatible shorthand for a `DatasetBackend` configured for -JRA55-style chunked NetCDF input: multiple time instances per file and no -inpainting (`inpainting = nothing`). Returns a `DatasetBackend` whose -metadata-driven `set!` dispatches to the JRA55 multi-year or repeat-year -methods defined below. 
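The cyclical window arithmetic these removed tests assert, `mod1(next_start + Nb - 1, Nt)`, can be checked in isolation. A sketch with illustrative values (not taken from the test suite):

```julia
# Sketch of the cyclical next-window prediction: each reload slides the
# window by Nb - 1 (the new start is the last in-memory index of the
# previous window) and wraps past the end of the time series via mod1.
predict_next(start, Nb, Nt) = mod1(start + Nb - 1, Nt)

Nt, Nb = 100, 4
predict_next(1, Nb, Nt)    # → 4: slide lands on the last in-memory index
predict_next(98, Nb, Nt)   # → 1: a window crossing Nt wraps back to the start
```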
-""" -JRA55NetCDFBackend(length) = DatasetBackend(length, nothing; inpainting=nothing) -JRA55NetCDFBackend(length, metadata::Metadata) = DatasetBackend(length, metadata; inpainting=nothing) -JRA55NetCDFBackend(start::Integer, length::Integer) = DatasetBackend(start, length, nothing; inpainting=nothing) -JRA55NetCDFBackend(start::Integer, length::Integer, metadata) = DatasetBackend(start, length, metadata; inpainting=nothing) - const JRA55NetCDFFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:JRA55Metadata}} const JRA55NetCDFFTSRepeatYear = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:Metadata{<:RepeatYearJRA55}}} const JRA55NetCDFFTSMultipleYears = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:Metadata{<:MultiYearJRA55}}} diff --git a/src/NumericalEarth.jl b/src/NumericalEarth.jl index 9dd134f19..394cb87b5 100644 --- a/src/NumericalEarth.jl +++ b/src/NumericalEarth.jl @@ -30,7 +30,6 @@ export os_papa_prescribed_fluxes, os_papa_prescribed_flux_boundary_conditions, OSPapaHourly, - JRA55NetCDFBackend, regrid_bathymetry, Metadata, Metadatum, @@ -128,7 +127,6 @@ using NumericalEarth.DataWrangling.EN4 using NumericalEarth.DataWrangling.ORCA using NumericalEarth.DataWrangling.WOA using NumericalEarth.DataWrangling.JRA55 -using NumericalEarth.DataWrangling.JRA55: JRA55NetCDFBackend using NumericalEarth.DataWrangling.OSPapa using PrecompileTools: @setup_workload, @compile_workload From 8dea43e59865b6628d6971c5a2e4a1e6277a47ec Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Thu, 23 Apr 2026 14:51:01 +0200 Subject: [PATCH 22/44] make sure we download corrupted files again --- test/download_utils.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/download_utils.jl b/test/download_utils.jl index 23a98118e..650508c9e 100644 --- a/test/download_utils.jl +++ b/test/download_utils.jl @@ -10,11 +10,12 @@ function emit_ci_warning(title, 
message) end function download_from_artifacts(filepath::AbstractString) - if !isfile(filepath) - filename = basename(filepath) - fallback_url = ARTIFACTS_BASE_URL * filename - @info "Downloading $filename from NumericalEarthArtifacts fallback..." - Downloads.download(fallback_url, filepath) + filename = basename(filepath) + fallback_url = ARTIFACTS_BASE_URL * filename + @info "Downloading $filename from NumericalEarthArtifacts fallback..." + mktemp(dirname(filepath)) do tmppath + Downloads.download(fallback_url, tmppath) + mv(tmppath, filepath; force=true) end end From f472c3d50ea468040d0694076d79f26ad2826082 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Thu, 23 Apr 2026 15:50:33 +0200 Subject: [PATCH 23/44] try like this --- .github/workflows/ci.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5c1c40922..49ae573ce 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -108,6 +108,8 @@ jobs: env: GPU_TEST: "true" OPENBLAS_NUM_THREADS: 1 + # Keep workers alive past ParallelTestRunner's default 3800 MB recycle threshold + JULIA_TEST_MAXRSS_MB: "7000" # Redirect temp to the workspace volume (bind-mounted from the # host's EBS disk) to avoid filling the container's overlay filesystem with # large data downloads and compiled artifacts. 
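The `download_from_artifacts` change above is the classic download-to-temp-then-rename recipe: a crashed or interrupted fetch can never leave a truncated file at the destination. A hedged standalone sketch, where `fetch!` is a placeholder for the real downloader (e.g. `Downloads.download`):

```julia
# Sketch of the atomic-download pattern: write into a temporary file in the
# destination directory, then rename into place. The temp file must live on
# the same filesystem as `filepath` so that mv is an atomic rename.
function atomic_download(fetch!, filepath::AbstractString)
    mktemp(dirname(filepath)) do tmppath, tmpio
        close(tmpio)                        # the downloader reopens the path itself
        fetch!(tmppath)
        mv(tmppath, filepath; force=true)   # atomic rename over any stale copy
    end
    return filepath
end

# Usage sketch: atomic_download(p -> Downloads.download(url, p), "data/file.nc")
```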
@@ -136,7 +138,7 @@ jobs: run: | # Run tests in verbose mode TEST_ARGS=(--verbose) - TEST_ARGS+=(--jobs=$(($(nproc) - 1))) + TEST_ARGS+=(--jobs=$(($(nproc) - 2))) echo "runtest_test_args=${TEST_ARGS[@]}" | tee -a "${GITHUB_ENV}" - name: Update registry shell: julia --color=yes {0} @@ -147,7 +149,7 @@ jobs: Pkg.Registry.update() - name: Run tests run: | - earlyoom -m 3 -s 100 -r 300 --prefer 'julia' & + earlyoom -m 3 -s 100 -r 300 & julia --project --color=yes --check-bounds=yes -e ' using Pkg; Pkg.test(; coverage=true, From af301f4810e9fabf57dbf4b040dcb77a275ee11d Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Thu, 23 Apr 2026 17:43:24 +0200 Subject: [PATCH 24/44] fix downloader --- test/download_utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/download_utils.jl b/test/download_utils.jl index 650508c9e..469cba410 100644 --- a/test/download_utils.jl +++ b/test/download_utils.jl @@ -13,7 +13,8 @@ function download_from_artifacts(filepath::AbstractString) filename = basename(filepath) fallback_url = ARTIFACTS_BASE_URL * filename @info "Downloading $filename from NumericalEarthArtifacts fallback..." 
- mktemp(dirname(filepath)) do tmppath + mktemp(dirname(filepath)) do tmppath, tmpio + close(tmpio) Downloads.download(fallback_url, tmppath) mv(tmppath, filepath; force=true) end From c03a97e0beb9d13404856b054e6e679456c824ce Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Mon, 27 Apr 2026 10:34:04 +0200 Subject: [PATCH 25/44] bounding_box -> region --- src/DataWrangling/JRA55/JRA55_metadata.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index d84042a00..c53bfc940 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -59,10 +59,10 @@ function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3)) longitude, latitude = jra55_native_interfaces(metadata_path(first(metadata))) - bbox = metadata.bounding_box - if !isnothing(bbox) - longitude, Nx = restrict(bbox.longitude, longitude, Nx) - latitude, Ny = restrict(bbox.latitude, latitude, Ny) + region = metadata.region + if !isnothing(region) + longitude, Nx = restrict(region.longitude, longitude, Nx) + latitude, Ny = restrict(region.latitude, latitude, Ny) end grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny), @@ -123,8 +123,8 @@ end # File name generation specific to each Dataset dataset # Note that `RepeatYearJRA55` has only one file associated, so the filename # is independent of the date. Override the multi-date fallback to return a plain String. -metadata_filename(::RepeatYearJRA55, name, date, bounding_box) = "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" -build_filename(::RepeatYearJRA55, name, dates::AbstractArray, bounding_box) = "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" +metadata_filename(::RepeatYearJRA55, name, date, region) = "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" +build_filename(::RepeatYearJRA55, name, dates::AbstractArray, region) = "RYF." 
* JRA55_dataset_variable_names[name] * ".1990_1991.nc"
 
 function metadata_filename(::MultiYearJRA55, name, date, region)
     shortname = JRA55_dataset_variable_names[name]

From c49cb349d167d9a5069f6f5ab65cc3226ae239a1 Mon Sep 17 00:00:00 2001
From: Simone Silvestri
Date: Mon, 27 Apr 2026 18:27:48 +0200
Subject: [PATCH 26/44] reenable the single column example

---
 docs/make.jl                                           | 2 +-
 examples/single_column_os_papa_simulation.jl           | 5 ++---
 src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl | 3 ++-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/docs/make.jl b/docs/make.jl
index 7bcf21270..f3f332f8d 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -29,7 +29,7 @@ mkpath(OUTPUT_DIR)
 # Set `build_always = false` for long-running examples that should only be built
 # on pushes to `main`/tags, or when the `build all examples` label is added to a PR.
 examples = [
-    # Example("Single-column ocean simulation", "single_column_os_papa_simulation", true),
+    Example("Single-column ocean simulation", "single_column_os_papa_simulation", true),
     Example("One-degree ocean--sea ice simulation", "one_degree_simulation", false),
     Example("Near-global ocean simulation", "near_global_ocean_simulation", false),
     Example("Global climate simulation", "global_climate_simulation", false),
diff --git a/examples/single_column_os_papa_simulation.jl b/examples/single_column_os_papa_simulation.jl
index 5c0439c12..7d8aef240 100644
--- a/examples/single_column_os_papa_simulation.jl
+++ b/examples/single_column_os_papa_simulation.jl
@@ -63,10 +63,9 @@ set!(ocean.model, T=Metadatum(:temperature, dataset=GLORYSMonthly(), region=col)
 # We build a `JRA55PrescribedAtmosphere` at the same location as the single-column grid
 # which is based on the JRA55 reanalysis.
-atmosphere = JRA55PrescribedAtmosphere(longitude = λ★,
-                                       latitude = φ★,
+atmosphere = JRA55PrescribedAtmosphere(region = Column(λ★, φ★)
                                        end_date = DateTime(1990, 1, 31), # Last day of the simulation
-                                       backend = InMemory())
+                                       time_indices_in_memory = 1000)
 
 # This builds a representation of the atmosphere on the small grid
diff --git a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl
index 80ec7a814..d35ead509 100644
--- a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl
+++ b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl
@@ -29,12 +29,13 @@ function JRA55PrescribedAtmosphere(architecture = CPU();
                                    time_indexing = Cyclical(),
                                    surface_layer_height = 10, # meters
                                    include_rivers_and_icebergs = false,
+                                   region = nothing,
                                    other_kw...)
 
     kw = (; time_indexing, time_indices_in_memory)
     kw = merge(kw, other_kw)
 
-    jra55_fts(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir), architecture; kw...)
+    jra55_fts(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir, region), architecture; kw...)
 
     ua = jra55_fts(:eastward_velocity)
     va = jra55_fts(:northward_velocity)

From a2289da61a7e5d0653b08bc972f9b47bcbff46be Mon Sep 17 00:00:00 2001
From: Simone Silvestri
Date: Mon, 27 Apr 2026 18:47:26 +0200
Subject: [PATCH 27/44] typo

---
 examples/single_column_os_papa_simulation.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/single_column_os_papa_simulation.jl b/examples/single_column_os_papa_simulation.jl
index 7d8aef240..680ebb401 100644
--- a/examples/single_column_os_papa_simulation.jl
+++ b/examples/single_column_os_papa_simulation.jl
@@ -63,7 +63,7 @@ set!(ocean.model, T=Metadatum(:temperature, dataset=GLORYSMonthly(), region=col)
 # We build a `JRA55PrescribedAtmosphere` at the same location as the single-column grid
 # which is based on the JRA55 reanalysis.
-atmosphere = JRA55PrescribedAtmosphere(region = Column(λ★, φ★) +atmosphere = JRA55PrescribedAtmosphere(region = Column(λ★, φ★), end_date = DateTime(1990, 1, 31), # Last day of the simulation time_indices_in_memory = 1000) From 6abba9a280c18533ded424cbc2e4dc35d0d79de7 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 28 Apr 2026 15:20:37 +0200 Subject: [PATCH 28/44] huge refactor --- src/DataWrangling/DataWrangling.jl | 1 + src/DataWrangling/ECCO/ECCO.jl | 36 +-- src/DataWrangling/ECCO/ECCO_atmosphere.jl | 18 +- .../JRA55/JRA55_field_time_series.jl | 184 ++--------- src/DataWrangling/JRA55/JRA55_metadata.jl | 52 +-- .../JRA55/JRA55_prescribed_atmosphere.jl | 33 +- src/DataWrangling/metadata.jl | 5 + src/DataWrangling/metadata_field.jl | 298 +++++------------- .../metadata_field_time_series.jl | 10 +- src/DataWrangling/set_region_data.jl | 207 ++++++++++++ test/test_column_field.jl | 181 +++++------ test/test_dataset_region.jl | 63 ++++ test/test_jra55_region.jl | 62 ++++ test/test_mangling.jl | 48 +++ test/test_metadata.jl | 49 +++ 15 files changed, 657 insertions(+), 590 deletions(-) create mode 100644 src/DataWrangling/set_region_data.jl create mode 100644 test/test_dataset_region.jl create mode 100644 test/test_jra55_region.jl create mode 100644 test/test_mangling.jl diff --git a/src/DataWrangling/DataWrangling.jl b/src/DataWrangling/DataWrangling.jl index 1606c557d..473d31825 100644 --- a/src/DataWrangling/DataWrangling.jl +++ b/src/DataWrangling/DataWrangling.jl @@ -200,6 +200,7 @@ default_mask_value(dataset) = NaN # Fundamentals include("metadata.jl") +include("set_region_data.jl") include("metadata_field.jl") include("dataset_backend.jl") include("metadata_field_time_series.jl") diff --git a/src/DataWrangling/ECCO/ECCO.jl b/src/DataWrangling/ECCO/ECCO.jl index 576e81014..e30d929ec 100644 --- a/src/DataWrangling/ECCO/ECCO.jl +++ b/src/DataWrangling/ECCO/ECCO.jl @@ -33,8 +33,7 @@ using NumericalEarth.DataWrangling: location, compute_mask, 
inpaint_mask!, - set_metadata_field!, - extract_column! + set_metadata_field! using KernelAbstractions: @kernel, @index @@ -371,37 +370,4 @@ inpainted_metadata_path(metadata::ECCOMetadatum) = joinpath(metadata.dir, inpain include("ECCO_atmosphere.jl") -##### -##### Column Field for ECCO datasets (which always download globally) -##### - -using Oceananigans.BoundaryConditions: fill_halo_regions! - -const ECCOColumnMetadatum = Metadatum{<:ECCODataset, <:Any, <:Column} - -function Oceananigans.Fields.Field(metadata::ECCOColumnMetadatum, arch=CPU(); - inpainting = default_inpainting(metadata), - mask = nothing, - halo = (3, 3, 3), - cache_inpainted_data = true) - - download_dataset(metadata) - column_grid = native_grid(metadata, arch; halo) - - # Build a full-grid Field without a region to load the global data - global_metadatum = Metadatum(metadata.name; - dataset = metadata.dataset, - date = metadata.dates) - - intermediate_field = Field(global_metadatum, arch; inpainting, mask, halo, cache_inpainted_data) - fill_halo_regions!(intermediate_field) - - # Extract the column - _, _, LZ = location(metadata) - column_field = Field{Nothing, Nothing, LZ}(column_grid) - extract_column!(column_field, intermediate_field, metadata.region) - - return column_field -end - end # Module diff --git a/src/DataWrangling/ECCO/ECCO_atmosphere.jl b/src/DataWrangling/ECCO/ECCO_atmosphere.jl index 56813ca4e..b2d9a3d34 100644 --- a/src/DataWrangling/ECCO/ECCO_atmosphere.jl +++ b/src/DataWrangling/ECCO/ECCO_atmosphere.jl @@ -34,16 +34,16 @@ function ECCOPrescribedAtmosphere(architecture = CPU(), FT = Float32; kw = (; time_indexing, time_indices_in_memory) kw = merge(kw, other_kw) - ecco_fts(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir), architecture; kw...) + ECCOFieldTimeSeries(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir), architecture; kw...) 
- ua = ecco_fts(:eastward_wind) - va = ecco_fts(:northward_wind) - Ta = ecco_fts(:air_temperature) - qa = ecco_fts(:air_specific_humidity) - pa = ecco_fts(:sea_level_pressure) - ℐꜜˡʷ = ecco_fts(:downwelling_longwave) - ℐꜜˢʷ = ecco_fts(:downwelling_shortwave) - Fr = ecco_fts(:rain_freshwater_flux) + ua = ECCOFieldTimeSeries(:eastward_wind) + va = ECCOFieldTimeSeries(:northward_wind) + Ta = ECCOFieldTimeSeries(:air_temperature) + qa = ECCOFieldTimeSeries(:air_specific_humidity) + pa = ECCOFieldTimeSeries(:sea_level_pressure) + ℐꜜˡʷ = ECCOFieldTimeSeries(:downwelling_longwave) + ℐꜜˢʷ = ECCOFieldTimeSeries(:downwelling_shortwave) + Fr = ECCOFieldTimeSeries(:rain_freshwater_flux) auxiliary_freshwater_flux = nothing freshwater_flux = (; rain = Fr) diff --git a/src/DataWrangling/JRA55/JRA55_field_time_series.jl b/src/DataWrangling/JRA55/JRA55_field_time_series.jl index 621bcddbd..efeedc6b0 100644 --- a/src/DataWrangling/JRA55/JRA55_field_time_series.jl +++ b/src/DataWrangling/JRA55/JRA55_field_time_series.jl @@ -1,96 +1,12 @@ using NumericalEarth.DataWrangling: all_dates, native_times using NumericalEarth.DataWrangling: compute_native_date_range +using NumericalEarth.DataWrangling: set_region_data! using Oceananigans.Grids: AbstractGrid using Oceananigans.OutputReaders: PartlyInMemory using Adapt import NumericalEarth.DataWrangling: retrieve_data -compute_bounding_nodes(::Nothing, ::Nothing, LH, hnodes) = nothing -compute_bounding_nodes(bounds, ::Nothing, LH, hnodes) = bounds - -function compute_bounding_nodes(x::Number, ::Nothing, LH, hnodes) - ϵ = convert(typeof(x), 0.001) # arbitrary? 
- return (x - ϵ, x + ϵ) -end - -# TODO: remove the allowscalar -function compute_bounding_nodes(::Nothing, grid, LH, hnodes) - hg = hnodes(grid, LH()) - h₁ = @allowscalar minimum(hg) - h₂ = @allowscalar maximum(hg) - return h₁, h₂ -end - -function compute_bounding_indices(::Nothing, hc) - Nh = length(hc) - return 1, Nh -end - -function compute_bounding_indices(bounds::Tuple, hc) - h₁, h₂ = bounds - Nh = length(hc) - - # The following should work. If ᵒ are the extrema of nodes we want to - # interpolate to, and the following is a sketch of the JRA55 native grid, - # - # 1 2 3 4 5 - # | | | | | | - # | x ᵒ | x | x | x ᵒ | x | - # | | | | | | - # 1 2 3 4 5 6 - # - # then for example, we should find that (iᵢ, i₂) = (1, 5). - # So we want to reduce the first index by one, and limit them - # both by the available data. There could be some mismatch due - # to the use of different coordinate systems (ie whether λ ∈ (0, 360) - # which we may also need to handle separately. - i₁ = searchsortedfirst(hc, h₁) - i₂ = searchsortedfirst(hc, h₂) - i₁ = max(1, i₁ - 1) - i₂ = min(Nh, i₂) - - return i₁, i₂ -end - -infer_longitudinal_topology(::Nothing) = Periodic - -function infer_longitudinal_topology(λbounds) - λ₁, λ₂ = λbounds - TX = λ₂ - λ₁ ≈ 360 ? Periodic : Bounded - return TX -end - -function compute_bounding_indices(longitude, latitude, grid, LX, LY, λc, φc) - λbounds = compute_bounding_nodes(longitude, grid, LX, λnodes) - φbounds = compute_bounding_nodes(latitude, grid, LY, φnodes) - - i₁, i₂ = compute_bounding_indices(λbounds, λc) - j₁, j₂ = compute_bounding_indices(φbounds, φc) - TX = infer_longitudinal_topology(λbounds) - - return i₁, i₂, j₁, j₂, TX -end - -# Migration shim — `JRA55FieldTimeSeries` was removed in favor of the -# generic `FieldTimeSeries(::Metadata)` path. The shim throws a clear -# error pointing at the new API; can be deleted once downstream callers -# (ClimaOcean.jl, user scripts) have migrated. -function JRA55FieldTimeSeries(args...; kwargs...) 
- error(""" - `JRA55FieldTimeSeries` was removed; JRA55 now uses the generic - dataset API. Migrate as: - - FieldTimeSeries(Metadata(:variable_name; dataset=RepeatYearJRA55(), - start_date, end_date, dir), - architecture; - time_indices_in_memory = N) - - The `InMemory()` backend is no longer supported for JRA55; pass - `time_indices_in_memory = length(metadata)` to keep the whole - series in memory. - """) -end const JRA55NetCDFFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:JRA55Metadata}} const JRA55NetCDFFTSRepeatYear = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:Metadata{<:RepeatYearJRA55}}} @@ -169,72 +85,48 @@ end # - ds["lat_bnds"]: bounding latitudes between which variables are averaged # - ds[shortname]: the variable data -# Simple case, only one file per variable, no need to deal with multiple files -function set!(fts::JRA55NetCDFFTSRepeatYear, backend=fts.backend) +# Split at the wrap point if `nn` cycles past 1 — DiskArrays requires sorted indices. +function jra55_read_data(ds, name, i, j, nn) + if issorted(nn) + return ds[name][i, j, nn] + else + m = findfirst(==(1), nn) + d1 = ds[name][i, j, nn[1:m-1]] + d2 = ds[name][i, j, nn[m:end]] + return cat(d1, d2; dims=3) + end +end +function set!(fts::JRA55NetCDFFTSRepeatYear, backend=fts.backend) metadata = backend.metadata + ds = Dataset(joinpath(metadata.dir, metadata.filename)) - filename = metadata.filename - path = joinpath(metadata.dir, filename) - ds = Dataset(path) - - # Nodes at the variable location λc = ds["lon"][:] φc = ds["lat"][:] - LX, LY, LZ = location(fts) - i₁, i₂, j₁, j₂, TX = compute_bounding_indices(nothing, nothing, fts.grid, LX, LY, λc, φc) - - nn = time_indices(fts) - nn = collect(nn) - name = dataset_variable_name(fts.backend.metadata) - - if issorted(nn) - data = ds[name][i₁:i₂, j₁:j₂, nn] - else - # The time indices may be cycling past 1; eg ti = [6, 7, 8, 1]. 
- # However, DiskArrays does not seem to support loading data with unsorted - # indices. So to handle this, we load the data in chunks, where each chunk's - # indices are sorted, and then glue the data together. - m = findfirst(n -> n == 1, nn) - n1 = nn[1:m-1] - n2 = nn[m:end] - - data1 = ds[name][i₁:i₂, j₁:j₂, n1] - data2 = ds[name][i₁:i₂, j₁:j₂, n2] - data = cat(data1, data2, dims=3) - end + nn = collect(time_indices(fts)) + name = dataset_variable_name(metadata) + raw = jra55_read_data(ds, name, :, :, nn) close(ds) + full_data = reshape(raw, length(λc), length(φc), 1, length(nn)) - copyto!(interior(fts, :, :, 1, :), data) + set_region_data!(fts, full_data, λc, φc, metadata) fill_halo_regions!(fts) - return nothing end -# Multi-year case: one file per calendar year. Match each in-memory slot's -# date to the file's time axis by (Y, M, D, H, min) components rather than -# by seconds-since-start. JRA55 files use `DateTimeNoLeap` (365-day) -# internally; doing seconds arithmetic against a Gregorian -# `metadata.dates` would diverge by a multiple of 86400 after every leap -# day passed and silently drop the affected slots. Component matching -# sidesteps this entirely. Feb 29 of leap years has no entry in the file -# and is skipped here too. +# JRA55 multi-year files use the no-leap calendar; matching by date components +# sidesteps the seconds-since-start drift across leap days. 
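The wrap-splitting in `jra55_read_data` above is language-agnostic: a cyclical in-memory window may ask for time indices like `[6, 7, 8, 1]`, while chunked readers (DiskArrays-backed NetCDF variables) require sorted indices, so the read is split at the wrap point and the chunks concatenated. A Python sketch of that logic (the `read_slice` callback and data values are illustrative, not part of the codebase):

```python
def read_wrapped(read_slice, nn):
    """Read time indices nn, which may wrap past the end of a cyclic
    series (e.g. [6, 7, 8, 1]). Chunked readers require sorted indices,
    so split at the wrap point (the slot whose index is 1) and glue the
    two sorted chunks back together."""
    if all(a < b for a, b in zip(nn, nn[1:])):  # already sorted (indices are distinct)
        return read_slice(nn)
    m = nn.index(1)  # first index of the wrapped-around tail
    return read_slice(nn[:m]) + read_slice(nn[m:])

# Example: pretend the file holds 8 time slices
data = [f"slice{i}" for i in range(1, 9)]
reader = lambda idx: [data[i - 1] for i in idx]  # 1-based, like the Julia code
print(read_wrapped(reader, [6, 7, 8, 1]))  # → ['slice6', 'slice7', 'slice8', 'slice1']
```

Each chunk passed to `read_slice` is sorted, so the sketch preserves the invariant the real `ds[name][i, j, nn]` read relies on.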
function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend) - metadata = backend.metadata name = dataset_variable_name(metadata) ftsn = collect(time_indices(fts)) slot_dates = metadata.dates[ftsn] - needed_files = unique(getfilename(metadata.filename, n) for n in ftsn) for file in needed_files - - path = joinpath(metadata.dir, file) - ds = Dataset(path) - + ds = Dataset(joinpath(metadata.dir, file)) file_dates = ds["time"][:] nn = Int[] @@ -250,36 +142,14 @@ function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend) if !isempty(nn) λc = ds["lon"][:] φc = ds["lat"][:] - LX, LY, LZ = location(fts) - i₁, i₂, j₁, j₂, TX = compute_bounding_indices(nothing, nothing, fts.grid, LX, LY, λc, φc) - - if issorted(nn) - data = ds[name][i₁:i₂, j₁:j₂, nn] - else - # Defensive: per-file `nn` is normally sorted because we - # iterate `slot_dates` in `ftsn` order and a single file - # holds a single contiguous year, but DiskArrays requires - # sorted indices so we split at the wrap if it occurs. 
- m = findfirst(n -> n == 1, nn) - n1 = nn[1:m-1] - n2 = nn[m:end] - - data1 = ds[name][i₁:i₂, j₁:j₂, n1] - data2 = ds[name][i₁:i₂, j₁:j₂, n2] - data = cat(data1, data2, dims=3) - end - - close(ds) - - for n in 1:length(nn) - copyto!(interior(fts, :, :, 1, ftsn_loc[n]), data[:, :, n]) - end - else - close(ds) + raw = jra55_read_data(ds, name, :, :, nn) + full_data = reshape(raw, length(λc), length(φc), 1, length(nn)) + set_region_data!(fts, full_data, λc, φc, metadata; slot_indices = ftsn_loc) end + close(ds) end fill_halo_regions!(fts) - return nothing end + diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index c53bfc940..29bbdab8e 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -3,7 +3,6 @@ using Dates using Downloads using Oceananigans.DistributedComputations -using Oceananigans.Grids: pop_flat_elements using NumericalEarth.DataWrangling using NumericalEarth.DataWrangling: Metadata, metadata_path, download_progress, AnyDateTime, DatasetBackend @@ -13,19 +12,20 @@ import Oceananigans.Fields: set! 
import Base import Oceananigans.Fields: set!, location -import NumericalEarth.DataWrangling: all_dates, - metadata_filename, - build_filename, - download_dataset, - default_download_directory, +import NumericalEarth.DataWrangling: all_dates, + metadata_filename, + build_filename, + download_dataset, + default_download_directory, dataset_variable_name, - available_variables, + available_variables, default_inpainting, getfilename, - z_interfaces, - native_grid, - longitude_interfaces, - latitude_interfaces + longitude_interfaces, + latitude_interfaces, + longitude_name, + latitude_name, + is_three_dimensional abstract type JRA55Dataset end @@ -48,30 +48,13 @@ function Base.size(::JRA55Dataset, variable) end end -longitude_interfaces(::JRA55Metadata) = (0, 360) -latitude_interfaces(::JRA55Metadata) = (-90, 90) +longitude_interfaces(md::JRA55Metadata) = first(jra55_native_interfaces(metadata_path(first(md)))) +latitude_interfaces(md::JRA55Metadata) = last(jra55_native_interfaces(metadata_path(first(md)))) -function native_grid(metadata::JRA55Metadata, arch=CPU(); halo = (3, 3)) - Nx, Ny, Nz, _ = size(metadata) - - FT = eltype(metadata) - halo = pop_flat_elements(halo, (Periodic, Bounded, Flat)) - - longitude, latitude = jra55_native_interfaces(metadata_path(first(metadata))) - - region = metadata.region - if !isnothing(region) - longitude, Nx = restrict(region.longitude, longitude, Nx) - latitude, Ny = restrict(region.latitude, latitude, Ny) - end - - grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny), - halo, longitude, latitude, - topology = (Periodic, Bounded, Flat)) - - return grid -end +longitude_name(::JRA55Metadata) = "lon" +latitude_name(::JRA55Metadata) = "lat" +# `lon_bnds`/`lat_bnds` hold only the left/lower interface; append the trailing one. function jra55_native_interfaces(path) ds = Dataset(path) λn = Array{Float64}(ds["lon_bnds"][1, :]) @@ -82,14 +65,11 @@ function jra55_native_interfaces(path) # so we need to append the trailing interface. 
push!(λn, λn[1] + 360) push!(φn, 90) - return λn, φn end -# JRA55 is a spatially 2D dataset is_three_dimensional(data::JRA55Metadata) = false -# Never inpaint JRA55 default_inpainting(::JRA55Metadata) = nothing # The whole range of dates in the different JRA55 datasets diff --git a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl index d35ead509..69fd5886d 100644 --- a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl +++ b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl @@ -13,12 +13,13 @@ JRA55PrescribedAtmosphere(arch::Distributed; kw...) = time_indexing = Cyclical(), surface_layer_height = 10, # meters include_rivers_and_icebergs = false, + region = nothing, other_kw...) -Return a [`PrescribedAtmosphere`](@ref) representing JRA55 reanalysis data. -Each atmospheric field is constructed via `FieldTimeSeries(::JRA55Metadata)`, -which uses a `DatasetBackend` parameterised by JRA55 metadata so that the -JRA55-specific `set!` (chunked-yearly NetCDF) is dispatched. +Return a [`PrescribedAtmosphere`](@ref) representing JRA55 reanalysis data. Each atmospheric field is constructed via +`FieldTimeSeries(::JRA55Metadata)`, which uses a `DatasetBackend` parameterised by JRA55 metadata so that the JRA55-specific +`set!` (chunked-yearly NetCDF) is dispatched. +The `region` keyword restricts the atmosphere to a sub-domain of the global JRA55 grid. """ function JRA55PrescribedAtmosphere(architecture = CPU(); dataset = RepeatYearJRA55(), @@ -35,17 +36,17 @@ function JRA55PrescribedAtmosphere(architecture = CPU(); kw = (; time_indexing, time_indices_in_memory) kw = merge(kw, other_kw) - jra55_fts(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir, region), architecture; kw...) + JRA55FieldTimeSeries(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir, region), architecture; kw...)
- ua = jra55_fts(:eastward_velocity) - va = jra55_fts(:northward_velocity) - Ta = jra55_fts(:temperature) - qa = jra55_fts(:specific_humidity) - pa = jra55_fts(:sea_level_pressure) - Fra = jra55_fts(:rain_freshwater_flux) - Fsn = jra55_fts(:snow_freshwater_flux) - ℐꜜˡʷ = jra55_fts(:downwelling_longwave_radiation) - ℐꜜˢʷ = jra55_fts(:downwelling_shortwave_radiation) + ua = JRA55FieldTimeSeries(:eastward_velocity) + va = JRA55FieldTimeSeries(:northward_velocity) + Ta = JRA55FieldTimeSeries(:temperature) + qa = JRA55FieldTimeSeries(:specific_humidity) + pa = JRA55FieldTimeSeries(:sea_level_pressure) + Fra = JRA55FieldTimeSeries(:rain_freshwater_flux) + Fsn = JRA55FieldTimeSeries(:snow_freshwater_flux) + ℐꜜˡʷ = JRA55FieldTimeSeries(:downwelling_longwave_radiation) + ℐꜜˢʷ = JRA55FieldTimeSeries(:downwelling_shortwave_radiation) freshwater_flux = (rain = Fra, snow = Fsn) @@ -54,8 +55,8 @@ function JRA55PrescribedAtmosphere(architecture = CPU(); # frequency than the rest of the JRA55 data. We use the # PrescribedAtmosphere `auxiliary_freshwater_flux` feature for them. if include_rivers_and_icebergs - Fri = jra55_fts(:river_freshwater_flux) - Fic = jra55_fts(:iceberg_freshwater_flux) + Fri = JRA55FieldTimeSeries(:river_freshwater_flux) + Fic = JRA55FieldTimeSeries(:iceberg_freshwater_flux) auxiliary_freshwater_flux = (rivers = Fri, icebergs = Fic) else auxiliary_freshwater_flux = nothing diff --git a/src/DataWrangling/metadata.jl b/src/DataWrangling/metadata.jl index cb1153e16..d8673b65f 100644 --- a/src/DataWrangling/metadata.jl +++ b/src/DataWrangling/metadata.jl @@ -86,6 +86,11 @@ Metadata(name, dataset, dates, region, dir) = Metadata(name, dataset, dates, reg is_three_dimensional(::Metadata) = true z_interfaces(md::Metadata) = z_interfaces(md.dataset) + +# NetCDF coordinate-variable names. Default follows CF standard; datasets +# whose files use different names (e.g. JRA55 uses `lon`/`lat`) override. 
+longitude_name(::Metadata) = "longitude" +latitude_name(::Metadata) = "latitude" longitude_interfaces(md::Metadata) = longitude_interfaces(md.dataset) latitude_interfaces(md::Metadata) = latitude_interfaces(md.dataset) diff --git a/src/DataWrangling/metadata_field.jl b/src/DataWrangling/metadata_field.jl index 6a7c60b4d..73eafd70f 100644 --- a/src/DataWrangling/metadata_field.jl +++ b/src/DataWrangling/metadata_field.jl @@ -2,9 +2,8 @@ using NCDatasets using JLD2 using NumericalEarth.InitialConditions: interpolate! using Statistics: median -using Oceananigans.Grids: λnodes, φnodes +using Oceananigans.Grids: λnodes, φnodes, Periodic, Bounded using Oceananigans.Architectures: on_architecture -using Oceananigans.Fields: fractional_x_index, fractional_y_index import Oceananigans.Fields: set!, Field, location @@ -23,14 +22,21 @@ restrict_location((LX, LY, LZ), ::Column) = (Nothing, Nothing, LZ) ##### restrict(::Nothing, interfaces, N) = interfaces, N +restrict(::Nothing, interfaces::NTuple{2,Any}, N) = interfaces, N +restrict(::Nothing, interfaces::AbstractVector, N) = interfaces, N -# TODO support stretched native grids -function restrict(bbox_interfaces, interfaces, N) - extent = interfaces[end] - interfaces[1] - rΔ = bbox_interfaces[2] - bbox_interfaces[1] - rN = round(Int, rΔ / extent * N) - rN = max(rN, 1) # at least one cell - return bbox_interfaces, rN +# Uniform native grid: 2-tuple endpoints expand to a uniform interface range. +function restrict(bbox_interfaces, interfaces::NTuple{2,Any}, N) + full = range(interfaces[1], interfaces[2]; length = N + 1) + return restrict(bbox_interfaces, full, N) +end + +# Stretched native grid: snap outward to the nearest native cell interfaces. 
+function restrict(bbox_interfaces, interfaces::AbstractVector, N) + i⁻ = max(searchsortedlast(interfaces, bbox_interfaces[1]), 1) + i⁺ = min(searchsortedfirst(interfaces, bbox_interfaces[2]), length(interfaces)) + rN = max(i⁺ - i⁻, 1) + return interfaces[i⁻:i⁺], rN end """ @@ -43,53 +49,62 @@ and a column `RectilinearGrid` for `Column` regions. native_grid(metadata::Metadata, arch=CPU(); halo=(3, 3, 3)) = construct_native_grid(metadata, metadata.region, arch; halo) -# Full global grid (no region restriction) +# 2D-only datasets (surface forcing like JRA55) skip the z dimension. function construct_native_grid(metadata, ::Nothing, arch; halo) - Nx, Ny, Nz, _ = size(metadata) - z = z_interfaces(metadata) FT = eltype(metadata) longitude = longitude_interfaces(metadata) latitude = latitude_interfaces(metadata) + Nx, Ny, Nz = size(metadata) - grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), - halo, longitude, latitude, z) - return grid + if is_three_dimensional(metadata) + z = z_interfaces(metadata) + return LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), + halo, longitude, latitude, z) + else + return LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny), + halo = halo[1:2], longitude, latitude, + topology = (Periodic, Bounded, Flat)) + end end -# BoundingBox-restricted LatitudeLongitudeGrid function construct_native_grid(metadata, bbox::BoundingBox, arch; halo) - Nx, Ny, Nz, _ = size(metadata) - z = z_interfaces(metadata) FT = eltype(metadata) - longitude = longitude_interfaces(metadata) - latitude = latitude_interfaces(metadata) + native_longitude = longitude_interfaces(metadata) + native_latitude = latitude_interfaces(metadata) - # TODO: can we restrict in `z` as well? 
- longitude, Nx = restrict(bbox.longitude, longitude, Nx) - latitude, Ny = restrict(bbox.latitude, latitude, Ny) + Nx, Ny, Nz = size(metadata) + longitude, Nx = restrict(bbox.longitude, native_longitude, Nx) + latitude, Ny = restrict(bbox.latitude, native_latitude, Ny) - # Clamp halo so it does not exceed grid size in any dimension - halo = min.(halo, (Nx, Ny, Nz)) + TX = infer_longitudinal_topology(native_longitude, longitude) - grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), - halo, longitude, latitude, z) - return grid + if is_three_dimensional(metadata) + z = z_interfaces(metadata) + return LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), + halo, longitude, latitude, z, + topology = (TX, Bounded, Bounded)) + else + return LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny), + halo = halo[1:2], longitude, latitude, + topology = (TX, Bounded, Flat)) + end end -# Column RectilinearGrid +# 2D-only datasets collapse to (Flat, Flat, Flat); 3D keep z Bounded. function construct_native_grid(metadata, col::Column, arch; halo) - _, _, Nz, _ = size(metadata) - z = z_interfaces(metadata) FT = eltype(metadata) + x = FT(col.longitude) + y = FT(col.latitude) - grid = RectilinearGrid(arch, FT; - size = Nz, - x = FT(col.longitude), - y = FT(col.latitude), - z, - halo = halo[3], - topology = (Flat, Flat, Bounded)) - return grid + if is_three_dimensional(metadata) + _, _, Nz, _ = size(metadata) + z = z_interfaces(metadata) + return RectilinearGrid(arch, FT; size = Nz, halo = halo[3], + x, y, z, topology = (Flat, Flat, Bounded)) + else + return RectilinearGrid(arch, FT; size = (), halo = (), + x, y, topology = (Flat, Flat, Flat)) + end end """ @@ -131,7 +146,7 @@ end architecture = CPU(), inpainting = default_inpainting(metadata), mask = nothing, - halo = (7, 7, 7), + halo = (3, 3, 3), cache_inpainted_data = true) Return a `Field` on `architecture` described by `metadata` with `halo` size. 
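The stretched-grid `restrict` method above snaps the requested bounding-box endpoints *outward* to the nearest native cell interfaces (`searchsortedlast` for the lower bound, `searchsortedfirst` for the upper), so the restricted grid always covers the requested region. A Python sketch of the same bracketing using `bisect` (the interface values are illustrative):

```python
import bisect

def restrict(bbox, interfaces):
    """Snap (lo, hi) outward to the nearest native cell interfaces of a
    (possibly stretched) sorted interface vector. Returns the restricted
    interface vector and its cell count (interfaces minus one)."""
    lo, hi = bbox
    # bisect_right(...) - 1 == last interface <= lo (Julia's searchsortedlast)
    i_lo = max(bisect.bisect_right(interfaces, lo) - 1, 0)
    # bisect_left(...) == first interface >= hi (Julia's searchsortedfirst)
    i_hi = min(bisect.bisect_left(interfaces, hi), len(interfaces) - 1)
    n_cells = max(i_hi - i_lo, 1)  # at least one cell
    return interfaces[i_lo:i_hi + 1], n_cells

ifaces = [0.0, 10.0, 25.0, 45.0, 70.0, 100.0]  # stretched native interfaces
print(restrict((12.0, 60.0), ifaces))  # → ([10.0, 25.0, 45.0, 70.0], 3)
```

Because the native interfaces are reused directly, the restricted grid stays aligned with the file's cells, which is what lets the region kernel index the file data with a plain integer offset.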
@@ -148,12 +163,12 @@ function Field(metadata::Metadatum, arch=CPU(); download_dataset(metadata) - # Column regions need special handling: the downloaded file may contain - # more data than a single column (e.g. CopernicusMarine returns a small - # grid around the point). Load onto an intermediate grid from the file's - # actual dimensions, then extract the column. + # Inpainting on a (Flat, Flat, *) column field is meaningless and the + # iterative algorithm doesn't terminate gracefully without horizontal + # neighbours; the NaN-aware bracket-blend in `set_region_data!` handles + # land cells directly. if metadata.region isa Column - return column_field_from_file(metadata, arch; inpainting, mask, halo, cache_inpainted_data) + inpainting = nothing end grid = native_grid(metadata, arch; halo) @@ -245,191 +260,22 @@ function set!(target_field::Field, metadata::Metadatum; kw...) return target_field end -##### -##### Column field construction -##### - -function column_field_from_file(metadata, arch; halo=(3, 3, 3), kw...) - column_grid = native_grid(metadata, arch; halo) - - # Read the file's actual dimensions to build a matching intermediate grid - path = metadata_path(metadata) - ds = Dataset(path) - varname = dataset_variable_name(metadata) - var = ds[varname] - data_size = size(var) - Nx_file, Ny_file = data_size[1], data_size[2] - - # Read coordinate arrays - lon_dimname = NCDatasets.dimnames(var)[1] - lat_dimname = NCDatasets.dimnames(var)[2] - λ = haskey(ds, lon_dimname) ? ds[lon_dimname][:] : ds["longitude"][:] - φ = haskey(ds, lat_dimname) ? ds[lat_dimname][:] : ds["latitude"][:] - close(ds) - - if reversed_latitude_axis(metadata.dataset) - reverse!(φ) - end - - _, _, Nz, _ = size(metadata) - z = z_interfaces(metadata) - FT = eltype(metadata) - - # Build cell interfaces from centers - Δλ = Nx_file > 1 ? λ[2] - λ[1] : FT(1) - λf = range(λ[1] - Δλ/2, stop = λ[end] + Δλ/2, length = Nx_file + 1) - - Δφ = Ny_file > 1 ? 
φ[2] - φ[1] : FT(1) - φf = range(φ[1] - Δφ/2, stop = φ[end] + Δφ/2, length = Ny_file + 1) - - halo = min.(halo, (Nx_file, Ny_file, Nz)) - - intermediate_grid = LatitudeLongitudeGrid(arch, FT; - size = (Nx_file, Ny_file, Nz), - halo, longitude = λf, latitude = φf, z) - - # Load data onto intermediate grid (no inpainting — columns have no horizontal neighbors) - LX, LY, LZ = dataset_location(metadata.dataset, metadata.name) - intermediate_field = Field{LX, LY, LZ}(intermediate_grid) - - data = retrieve_data(metadata) - set_metadata_field!(intermediate_field, data, metadata) - fill_halo_regions!(intermediate_field) - - # Extract column - _, _, LZ_col = location(metadata) - col_field = Field{Nothing, Nothing, LZ_col}(column_grid) - extract_column!(col_field, intermediate_field, metadata.region) - - return col_field -end - -##### -##### Column extraction utilities -##### - -# Dispatch extraction on interpolation method -function extract_column!(column_field, intermediate_field, col::Column) - extract_column!(column_field, intermediate_field, col, col.interpolation) -end - -function extract_column!(column_field, intermediate_field, col, ::Linear) - grid = intermediate_field.grid - arch = architecture(grid) - LX, LY, LZ = Oceananigans.Fields.location(intermediate_field) - locs = (LX(), LY(), LZ()) - - # Fractional indices (1-based, continuous) - fi = fractional_x_index(col.longitude, locs, grid) - fj = fractional_y_index(col.latitude, locs, grid) - - # Lower-left index and weights - i₁ = clamp(floor(Int, fi), 1, size(grid, 1)) - j₁ = clamp(floor(Int, fj), 1, size(grid, 2)) - i₂ = clamp(i₁ + 1, 1, size(grid, 1)) - j₂ = clamp(j₁ + 1, 1, size(grid, 2)) - - wx = clamp(fi - floor(fi), 0, 1) - wy = clamp(fj - floor(fj), 0, 1) - - launch!(arch, column_field.grid, :z, _bilinear_interpolate_column!, - column_field, intermediate_field, i₁, j₁, i₂, j₂, wx, wy) - - return nothing -end - -@kernel function _bilinear_interpolate_column!(column_field, source, i₁, j₁, i₂, j₂, wx, wy) - k 
= @index(Global, Linear) - @inbounds begin - v00 = source[i₁, j₁, k] - v10 = source[i₂, j₁, k] - v01 = source[i₁, j₂, k] - v11 = source[i₂, j₂, k] - column_field[1, 1, k] = (1 - wx) * (1 - wy) * v00 + - wx * (1 - wy) * v10 + - (1 - wx) * wy * v01 + - wx * wy * v11 - end -end - -function extract_column!(column_field, intermediate_field, col, ::Nearest) - grid = intermediate_field.grid - arch = architecture(grid) - LX, LY, LZ = Oceananigans.Fields.location(intermediate_field) - locs = (LX(), LY(), LZ()) # fractional index functions expect instances, not types - - # Use Oceananigans' fractional index machinery (handles cyclic longitude etc.) - i★ = round(Int, fractional_x_index(col.longitude, locs, grid)) - j★ = round(Int, fractional_y_index(col.latitude, locs, grid)) - - launch!(arch, column_field.grid, :z, copy_column!, column_field, intermediate_field, i★, j★) - - return nothing -end - -@kernel function copy_column!(column_field, source_field, i★, j★) - k = @index(Global, Linear) - @inbounds column_field[1, 1, k] = source_field[i★, j★, k] -end - -# manglings -struct ShiftSouth end -struct AverageNorthSouth end - -@inline mangle(i, j, data, ::Nothing) = @inbounds data[i, j] -@inline mangle(i, j, data, ::ShiftSouth) = @inbounds data[i, j-1] -@inline mangle(i, j, data, ::AverageNorthSouth) = @inbounds (data[i, j+1] + data[i, j]) / 2 - -@inline mangle(i, j, k, data, ::Nothing) = @inbounds data[i, j, k] -@inline mangle(i, j, k, data, ::ShiftSouth) = @inbounds data[i, j-1, k] -@inline mangle(i, j, k, data, ::AverageNorthSouth) = @inbounds (data[i, j+1, k] + data[i, j, k]) / 2 - function set_metadata_field!(field, data, metadatum) - grid = field.grid - arch = architecture(grid) - - Nx, Ny, Nz = size(metadatum) - mangling = if size(data, 2) == Ny-1 - ShiftSouth() - elseif size(data, 2) == Ny+1 - AverageNorthSouth() - else - nothing - end - - conversion = conversion_units(metadatum) - - if ndims(data) == 2 - _kernel = _set_2d_metadata_field! 
- spec = :xy - else - _kernel = _set_3d_metadata_field! - spec = :xyz - end - - data = on_architecture(arch, data) - Oceananigans.Utils.launch!(arch, grid, spec, _kernel, field, data, mangling, conversion) - + full_data = ndims(data) == 2 ? reshape(data, size(data, 1), size(data, 2), 1) : data + λc, φc = read_file_coords(metadatum) + set_region_data!(field, full_data, λc, φc, metadatum) return nothing end -@kernel function _set_2d_metadata_field!(field, data, mangling, conversion) - i, j = @index(Global, NTuple) - FT = eltype(field) - d = mangle(i, j, data, mangling) - d = nan_convert_missing(FT, d) - d = convert_units(d, conversion) - @inbounds field[i, j, 1] = d -end - -@kernel function _set_3d_metadata_field!(field, data, mangling, conversion) - i, j, k = @index(Global, NTuple) - FT = eltype(field) - d = mangle(i, j, k, data, mangling) - d = nan_convert_missing(FT, d) - d = convert_units(d, conversion) - - @inbounds field[i, j, k] = d +# Read the lon/lat cell centres from the NetCDF file using the names supplied +# by the dataset's `longitude_name` / `latitude_name` traits. +function read_file_coords(metadatum) + ds = Dataset(metadata_path(metadatum)) + λc = ds[longitude_name(metadatum)][:] + φc = ds[latitude_name(metadatum)][:] + close(ds) + reversed_latitude_axis(metadatum.dataset) && reverse!(φc) + return λc, φc end ##### diff --git a/src/DataWrangling/metadata_field_time_series.jl b/src/DataWrangling/metadata_field_time_series.jl index 23171a0e3..42aa50c5a 100644 --- a/src/DataWrangling/metadata_field_time_series.jl +++ b/src/DataWrangling/metadata_field_time_series.jl @@ -1,7 +1,3 @@ -using Oceananigans.Architectures: AbstractArchitecture, architecture -using Oceananigans.Grids: AbstractGrid -using Oceananigans.Fields: interpolate! 
- import Oceananigans.OutputReaders: update_field_time_series!, FieldTimeSeries """ @@ -49,8 +45,6 @@ function FieldTimeSeries(metadata::Metadata, grid::AbstractGrid; download_dataset(metadata) - # Detect "the user's grid IS the native grid" structurally - on_native_grid = grid == native_grid(metadata, architecture(grid)) times = native_times(metadata) # Make sure we do not use more indices than the ones available! @@ -59,8 +53,8 @@ end inpainting isa Int && (inpainting = NearestNeighborInpainting(inpainting)) - is_native = grid == native_grid(metadata) - backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid=is_native, inpainting, cache_inpainted_data) + on_native_grid = grid == native_grid(metadata, architecture(grid)) + backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data) loc = LX, LY, LZ = location(metadata) boundary_conditions = FieldBoundaryConditions(grid, instantiate.(loc)) diff --git a/src/DataWrangling/set_region_data.jl b/src/DataWrangling/set_region_data.jl new file mode 100644 index 000000000..c2bd910a4 --- /dev/null +++ b/src/DataWrangling/set_region_data.jl @@ -0,0 +1,207 @@ +using Oceananigans.Utils: launch! +using Oceananigans.Architectures: AbstractArchitecture, architecture +using Oceananigans.Grids: AbstractGrid, Periodic, Bounded, λnodes, φnodes +using Oceananigans.Fields: Field, interior, interpolate! +using GPUArraysCore: @allowscalar + +##### +##### Region helpers shared across dataset backends +##### + +function compute_bounding_nodes(grid, LH, hnodes) + hg = hnodes(grid, LH()) + h₁ = @allowscalar minimum(hg) + h₂ = @allowscalar maximum(hg) + return h₁, h₂ +end + +# `ε` forgives Float32 to Float64 promotion noise so the slice doesn't lose a
+function compute_bounding_indices(bounds::Tuple, hc) + h₁, h₂ = bounds + Nh = length(hc) + ε = eps(Float32) * max(one(eltype(hc)), abs(h₁), abs(h₂)) + i₁ = max(searchsortedfirst(hc, h₁ - ε), 1) + i₂ = min( searchsortedlast(hc, h₂ + ε), Nh) + return i₁, i₂ +end + +# Periodic only when the restricted span equals the full native span. +function infer_longitudinal_topology(full_longitude, restricted_longitude) + full_span = full_longitude[end] - full_longitude[1] + restricted_span = restricted_longitude[end] - restricted_longitude[1] + return restricted_span ≈ full_span ? Periodic : Bounded +end + +##### +##### Mangling utilities +##### + +struct ShiftSouth end +struct AverageNorthSouth end + +# `mangle(i, j, k, data, mangling)` reads file `data` at metadata-grid index +# `(i, j, k)`, accounting for staggered lat-axis offsets. Used inside the +# region-aware kernel. +@inline mangle(i, j, k, data, ::Nothing) = @inbounds data[i, j, k] +@inline mangle(i, j, k, data, ::ShiftSouth) = @inbounds data[i, max(j - 1, 1), k] +@inline mangle(i, j, k, data, ::AverageNorthSouth) = @inbounds (data[i, j, k] + data[i, j + 1, k]) / 2 + +##### +##### Region-aware filling for Fields and FieldTimeSeries via a single kernel. +##### +##### `read_data(data, i, j, k, region_info, mangling)` is the only access point: it composes the +##### file-axis offset (region) with the lat-axis remap (mangling). All region/mangling combinations +##### go through one kernel that handles NaN + unit conversion in the same pass. +##### + +struct BBoxOffset + di :: Int + dj :: Int +end + +struct ColumnInfo{F, I} + i⁻ :: Int + i⁺ :: Int + j⁻ :: Int + j⁺ :: Int + wx :: F + wy :: F + ℑ :: I +end + +# `region_info` resolves the target's region to a kernel-friendly struct. 
+region_info(::Nothing, target, λc, φc) = nothing + +function region_info(::BoundingBox, target, λc, φc) + LX, LY, _ = Oceananigans.Fields.location(target) + i₁, _ = compute_bounding_indices(compute_bounding_nodes(target.grid, LX, λnodes), λc) + j₁, _ = compute_bounding_indices(compute_bounding_nodes(target.grid, LY, φnodes), φc) + return BBoxOffset(i₁ - 1, j₁ - 1) +end + +function region_info(col::Column, target, λc, φc) + i⁻, i⁺, wx = bracket_with_weight(λc, col.longitude; period = infer_longitudinal_period(λc)) + j⁻, j⁺, wy = bracket_with_weight(φc, col.latitude) # latitude is never cyclic + FT = eltype(target) + return ColumnInfo(i⁻, i⁺, j⁻, j⁺, FT(wx), FT(wy), col.interpolation) +end + +# 360 if `λc` spans the full globe (cyclic), else `nothing`. +function infer_longitudinal_period(λc) + length(λc) < 2 && return nothing + Δ = λc[2] - λc[1] + span = λc[end] - λc[1] + Δ + return span ≈ 360 ? 360 : nothing +end + +# Cyclic-aware bracketing. With `period`, the cell between `coords[end]` and +# `coords[1] + period` is the wrap cell: returns `(n, 1, w)` so the blend +# reads `data[n, …]` and `data[1, …]`. +function bracket_with_weight(coords, x; period = nothing) + n = length(coords) + + if !isnothing(period) + x = coords[1] + mod(x - coords[1], period) + if x > coords[end] + Δ = (coords[1] + period) - coords[end] + w = (x - coords[end]) / Δ + return n, 1, clamp(w, 0, 1) + end + end + + i⁺ = searchsortedfirst(coords, x) + i⁺ = clamp(i⁺, 2, n) + i⁻ = i⁺ - 1 + Δ = coords[i⁺] - coords[i⁻] + w = Δ == 0 ? zero(x) : (x - coords[i⁻]) / Δ + return i⁻, i⁺, clamp(w, 0, 1) +end + +# `mangling_for` detects a file/grid lat-axis offset from the data shape. +function mangling_for(metadata, data_lat_count) + Ny = size(metadata)[2] + return data_lat_count == Ny - 1 ? ShiftSouth() : + data_lat_count == Ny + 1 ? 
AverageNorthSouth() : + nothing +end + +# `read_data(data, i, j, k, region, mangling, FT)` returns the file value at +# the grid's (i, j, k) as `FT`, with `Missing` already converted to NaN. +@inline read_data(data, i, j, k, ::Nothing, mangling, FT) = nan_convert_missing(FT, mangle(i, j, k, data, mangling)) +@inline read_data(data, i, j, k, b::BBoxOffset, mangling, FT) = nan_convert_missing(FT, mangle(i + b.di, j + b.dj, k, data, mangling)) +@inline read_data(data, _, _, k, c::ColumnInfo, mangling, FT) = blend(data, c, k, mangling, c.ℑ, FT) + +# NaN-aware bilinear blend: drop NaN corners and renormalise weights. Each +# corner is `nan_convert_missing`-coerced to `FT` before mixing, so `Missing` +# values from NetCDF arrays don't poison the arithmetic. Returns NaN only +# when all four corners are land. +@inline function blend(data, c, k, mangling, ::Linear, FT) + d00 = nan_convert_missing(FT, mangle(c.i⁻, c.j⁻, k, data, mangling)) + d10 = nan_convert_missing(FT, mangle(c.i⁺, c.j⁻, k, data, mangling)) + d01 = nan_convert_missing(FT, mangle(c.i⁻, c.j⁺, k, data, mangling)) + d11 = nan_convert_missing(FT, mangle(c.i⁺, c.j⁺, k, data, mangling)) + w00 = (1 - c.wx) * (1 - c.wy) * !isnan(d00) + w10 = c.wx * (1 - c.wy) * !isnan(d10) + w01 = (1 - c.wx) * c.wy * !isnan(d01) + w11 = c.wx * c.wy * !isnan(d11) + Σw = w00 + w10 + w01 + w11 + Σw == 0 && return convert(FT, NaN) + z = zero(FT) + return (w00 * ifelse(isnan(d00), z, d00) + + w10 * ifelse(isnan(d10), z, d10) + + w01 * ifelse(isnan(d01), z, d01) + + w11 * ifelse(isnan(d11), z, d11)) / Σw +end + +@inline function blend(data, c, k, mangling, ::Nearest, FT) + i = c.wx ≥ 0.5 ? c.i⁺ : c.i⁻ + j = c.wy ≥ 0.5 ? c.j⁺ : c.j⁻ + near = nan_convert_missing(FT, mangle(i, j, k, data, mangling)) + # If the closest corner is land, fall back to the NaN-aware Linear blend. + return isnan(near) ? 
blend(data, c, k, mangling, Linear(), FT) : near +end + +@kernel function _set_region_kernel!(dst, data, region, mangling, conversion, FT) + i, j, k = @index(Global, NTuple) + d = read_data(data, i, j, k, region, mangling, FT) + d = convert_units(d, conversion) + @inbounds dst[i, j, k] = d +end + +""" + set_region_data!(target, data, λc, φc, metadata) + +Fill the region of `target` (Field or FieldTimeSeries) implied by `metadata.region` from `data`, +applying mangling, NaN conversion, and unit conversion in a single GPU-friendly kernel pass. +""" +function set_region_data!(target::Field, data, λc, φc, metadata; + mangling = mangling_for(metadata, size(data, 2)), + conversion = conversion_units(metadata)) + + region = region_info(metadata.region, target, λc, φc) + FT = eltype(target) + grid = target.grid + arch = architecture(grid) + data = on_architecture(arch, data) + launch!(arch, grid, :xyz, _set_region_kernel!, interior(target), data, region, mangling, conversion, FT) + return nothing +end + +function set_region_data!(target::FieldTimeSeries, data, λc, φc, metadata; + mangling = mangling_for(metadata, size(data, 2)), + conversion = conversion_units(metadata), + slot_indices = 1:size(target, 4)) + + region = region_info(metadata.region, target, λc, φc) + grid = target.grid + arch = architecture(grid) + FT = eltype(target) + data = on_architecture(arch, data) + for (data_time, slot_time) in zip(axes(data, 4), slot_indices) + dest = view(interior(target), :, :, :, slot_time) + slice = view(data, :, :, :, data_time) + launch!(arch, grid, :xyz, _set_region_kernel!, dest, slice, region, mangling, conversion, FT) + end + return nothing +end diff --git a/test/test_column_field.jl b/test/test_column_field.jl index ea33bad6c..16ea52f50 100644 --- a/test/test_column_field.jl +++ b/test/test_column_field.jl @@ -3,8 +3,8 @@ include("runtests_setup.jl") using NumericalEarth.DataWrangling: Column, Linear, Nearest, BoundingBox, native_grid, restrict_location, dataset_location 
- -using NumericalEarth.DataWrangling: extract_column! +using NumericalEarth.DataWrangling: bracket_with_weight, infer_lon_period, + region_info, blend, ColumnInfo using Oceananigans using Oceananigans.BoundaryConditions: fill_halo_regions! @@ -15,121 +15,96 @@ using Oceananigans.Grids: λnodes, φnodes, topology, Flat, Bounded, Periodic const test_longitude = 12.0 const test_latitude = -50.0 -@testset "extract_column! with Nearest interpolation" begin - for arch in test_architectures - A = typeof(arch) - @testset "Nearest extraction on $A" begin - # Create a LatitudeLongitudeGrid with spatially varying data - intermediate_grid = LatitudeLongitudeGrid(arch; - size = (4, 4, 2), - longitude = (0, 4), - latitude = (0, 4), - z = (-20, 0)) - - intermediate_field = CenterField(intermediate_grid) - - # Set distinct values at each horizontal point - for i in 1:4, j in 1:4, k in 1:2 - @allowscalar intermediate_field[i, j, k] = 10 * i + j + 0.1 * k - end - fill_halo_regions!(intermediate_field) +@testset "bracket_with_weight (non-cyclic)" begin + coords = [0.5, 1.5, 2.5, 3.5] - # Column near grid point (3, 2) → lon≈2.5, lat≈1.5 - col = Column(2.5, 1.5; interpolation=Nearest()) - column_grid = RectilinearGrid(arch; - size = 2, - x = 2.5, - y = 1.5, - z = (-20, 0), - halo = 3, - topology = (Flat, Flat, Bounded)) + # Interior point, exact midpoint between cells. + i⁻, i⁺, w = bracket_with_weight(coords, 2.0) + @test (i⁻, i⁺) == (2, 3) + @test w ≈ 0.5 - column_field = Field{Nothing, Nothing, Center}(column_grid) + # Off-grid below: clamps to first interval. + i⁻, i⁺, w = bracket_with_weight(coords, -1.0) + @test (i⁻, i⁺) == (1, 2) + @test w == 0.0 - extract_column!(column_field, intermediate_field, col, Nearest()) + # Off-grid above: clamps to last interval. 
+ i⁻, i⁺, w = bracket_with_weight(coords, 5.0) + @test (i⁻, i⁺) == (3, 4) + @test w == 1.0 - # Find expected nearest indices - λnodes_arr = λnodes(intermediate_grid, Center(); with_halos=false) - φnodes_arr = φnodes(intermediate_grid, Center(); with_halos=false) - i★ = argmin(abs.(λnodes_arr .- 2.5)) - j★ = argmin(abs.(φnodes_arr .- 1.5)) + # On the right-most centre: weight = 1. + i⁻, i⁺, w = bracket_with_weight(coords, 3.5) + @test (i⁻, i⁺) == (3, 4) + @test w ≈ 1.0 +end - @allowscalar begin - for k in 1:2 - @test column_field[1, 1, k] == intermediate_field[i★, j★, k] - end - end - end +@testset "bracket_with_weight (cyclic wrap)" begin + coords = collect(0.5:1.0:359.5) # global 1° centres + n = length(coords) - @testset "Nearest extraction preserves vertical profile on $A" begin - intermediate_grid = LatitudeLongitudeGrid(arch; - size = (3, 3, 5), - longitude = (10, 13), - latitude = (40, 43), - z = (-50, 0)) + # Interior point — period is a no-op there. + i⁻, i⁺, w = bracket_with_weight(coords, 180.0; period = 360) + @test (i⁻, i⁺) == (180, 181) + @test w ≈ 0.5 - intermediate_field = CenterField(intermediate_grid) + # Wrap cell: x just below the period boundary. + i⁻, i⁺, w = bracket_with_weight(coords, 359.99; period = 360) + @test (i⁻, i⁺) == (n, 1) + @test 0 < w < 1 - # Set a vertical profile: value = depth level - for k in 1:5 - interior(intermediate_field)[:, :, k] .= Float64(k) - end - fill_halo_regions!(intermediate_field) + # x past the period: mod wraps it back into the regular range. 
+    i⁻, i⁺, w = bracket_with_weight(coords, 360.5; period = 360)
+    @test i⁻ == 1
+    @test i⁺ == 2 || (i⁻ == n && i⁺ == 1) # right at coords[1] boundary
+end

-            col = Column(11.5, 41.5; interpolation=Nearest())
-            column_grid = RectilinearGrid(arch;
-                                          size = 5,
-                                          x = 11.5,
-                                          y = 41.5,
-                                          z = (-50, 0),
-                                          halo = 3,
-                                          topology = (Flat, Flat, Bounded))
+@testset "infer_lon_period" begin
+    @test infer_lon_period(collect(0.5:1.0:359.5)) == 360
+    @test infer_lon_period(collect(-179.75:0.5:179.75)) == 360
+    @test infer_lon_period([10.0, 11.0, 12.0]) === nothing
+    @test infer_lon_period([100.0]) === nothing
+end

-            column_field = Field{Nothing, Nothing, Center}(column_grid)
-            extract_column!(column_field, intermediate_field, col, Nearest())
+@testset "NaN-aware blend" begin
+    # 2x2x1 synthetic data; column at the centre point with equal weights.
+    c = ColumnInfo(1, 2, 1, 2, 0.5f0, 0.5f0, Linear())
+    FT = Float32

-            @allowscalar begin
-                for k in 1:5
-                    @test column_field[1, 1, k] == k
-                end
-            end
-        end
-    end
+    # All-valid: result is the simple average.
+    data_full = reshape(Float32[1 2; 3 4], 2, 2, 1)
+    @test blend(data_full, c, 1, nothing, c.ℑ, FT) ≈ 2.5f0
+
+    # All-NaN: result is NaN.
+    data_nan = fill(NaN32, 2, 2, 1)
+    @test isnan(blend(data_nan, c, 1, nothing, c.ℑ, FT))
+
+    # Partial: bottom-right corner is NaN, weights renormalise over the rest.
+    # Weights become (0.25, 0.25, 0.25, 0); Σw = 0.75; weighted sum = 0.25 * (1 + 2 + 3) = 1.5;
+    # result = 1.5 / 0.75 = 2.0, the renormalised mean of the three valid corners.
+    data_part = reshape(Float32[1 2; 3 NaN32], 2, 2, 1)
+    @test blend(data_part, c, 1, nothing, c.ℑ, FT) ≈ 2.0f0
+
+    # Missing values from NetCDF-style arrays are treated as NaN.
+    data_missing = reshape(Union{Missing, Float32}[1.0 2.0; 3.0 missing], 2, 2, 1)
+    @test blend(data_missing, c, 1, nothing, c.ℑ, FT) ≈ 2.0f0
 end

-@testset "extract_column! dispatch routes on interpolation type" begin
-    for arch in test_architectures
-        A = typeof(arch)
-        @testset "Dispatch on $A" begin
-            intermediate_grid = LatitudeLongitudeGrid(arch;
-                                                      size = (4, 4, 2),
-                                                      longitude = (0, 4),
-                                                      latitude = (0, 4),
-                                                      z = (-20, 0))
-
-            intermediate_field = CenterField(intermediate_grid)
-            interior(intermediate_field) .= 42.0
-            fill_halo_regions!(intermediate_field)
-
-            column_grid = RectilinearGrid(arch;
-                                          size = 2,
-                                          x = 2.0,
-                                          y = 2.0,
-                                          z = (-20, 0),
-                                          halo = 3,
-                                          topology = (Flat, Flat, Bounded))
-
-            # Column dispatch routes to the correct method
-            col_nearest = Column(2.0, 2.0; interpolation=Nearest())
-            cf = Field{Nothing, Nothing, Center}(column_grid)
-            extract_column!(cf, intermediate_field, col_nearest)
+@testset "blend dispatches Linear vs Nearest" begin
+    data = reshape(Float32[1 2; 3 4], 2, 2, 1)
+    FT = Float32

-            @allowscalar begin
-                @test cf[1, 1, 1] == 42.0
-                @test cf[1, 1, 2] == 42.0
-            end
-        end
-    end
+    # wx = wy = 0.5 → average for Linear; arbitrary corner for Nearest.
+    c_lin = ColumnInfo(1, 2, 1, 2, 0.5f0, 0.5f0, Linear())
+    @test blend(data, c_lin, 1, nothing, c_lin.ℑ, FT) ≈ 2.5f0
+
+    # wx = 0.7, wy = 0.7 → both above 0.5 → picks i⁺, j⁺ = data[2,2,1] = 4.
+    c_near = ColumnInfo(1, 2, 1, 2, 0.7f0, 0.7f0, Nearest())
+    @test blend(data, c_near, 1, nothing, c_near.ℑ, FT) ≈ 4.0f0
+
+    # wx = 0.3, wy = 0.3 → both below 0.5 → picks i⁻, j⁻ = data[1,1,1] = 1.
+ c_near2 = ColumnInfo(1, 2, 1, 2, 0.3f0, 0.3f0, Nearest()) + @test blend(data, c_near2, 1, nothing, c_near2.ℑ, FT) ≈ 1.0f0 end @testset "End-to-end Column Field construction" begin diff --git a/test/test_dataset_region.jl b/test/test_dataset_region.jl new file mode 100644 index 000000000..92e151921 --- /dev/null +++ b/test/test_dataset_region.jl @@ -0,0 +1,63 @@ +include("runtests_setup.jl") + +using NumericalEarth.DataWrangling: BoundingBox, Column, metadata_path +using Oceananigans: λnodes, φnodes, Center, interior +using Oceananigans.Fields: Field +using NCDatasets +using Dates + +@testset "Cross-dataset region support (snapshot path)" begin + arch = CPU() + + @testset "ECCO4 BoundingBox loads the right window" begin + # ECCO4Monthly stores longitude on [-180, 180]; pick a bbox far from + # the file's SW corner (which is the Antarctic ocean, all zeros) so + # that positional vs. coordinate indexing differ. + bbox = BoundingBox(longitude=(-60, 60), latitude=(-30, 30)) + md = Metadatum(:temperature; dataset=ECCO4Monthly(), + date=DateTime(1993, 1, 1), region=bbox) + f = Field(md, arch; inpainting=nothing, cache_inpainted_data=false) + + # Grid coordinates must fall inside the requested bbox (with a 1° + # tolerance for ECCO4's 0.5° spacing). + λg = λnodes(f.grid, Center()) + φg = φnodes(f.grid, Center()) + @test minimum(λg) ≥ -60 - 1.0 + @test maximum(λg) ≤ 60 + 1.0 + @test minimum(φg) ≥ -30 - 1.0 + @test maximum(φg) ≤ 30 + 1.0 + @test any(!iszero, interior(f)) + + # Correctness: hand-extract the reference value at (0°, 0°, surface) + # directly from the NetCDF file and compare against the bbox-restricted + # field at the same physical point. This catches the silent SW-corner + # positional-indexing bug in `set_metadata_field!`. + path = metadata_path(md) + ds = Dataset(path) + # ECCO4 surface is k=1 in the raw file (Z=-5 m). After + # `retrieve_data` reverses dims=3, surface becomes k=end in + # `interior(f)`. 
+ T_full = ds["THETA"][:, :, 1, 1] + λfile = ds["longitude"][:] + φfile = ds["latitude"][:] + close(ds) + i★ = argmin(abs.(λfile .- 0)) + j★ = argmin(abs.(φfile .- 0)) + T_ref = T_full[i★, j★] + # Same physical point in the bbox-restricted field. + i_grid = argmin(abs.(λg .- 0)) + j_grid = argmin(abs.(φg .- 0)) + T_field = interior(f)[i_grid, j_grid, end] + @test T_field ≈ T_ref rtol=1e-3 + end + + @testset "ECCO4 Column extracts a single point" begin + col = Column(150.0, 0.0) + md = Metadatum(:temperature; dataset=ECCO4Monthly(), + date=DateTime(1993, 1, 1), region=col) + f = Field(md, arch; inpainting=nothing, cache_inpainted_data=false) + @test size(f.grid, 1) == 1 + @test size(f.grid, 2) == 1 + @test any(!iszero, interior(f)) + end +end diff --git a/test/test_jra55_region.jl b/test/test_jra55_region.jl new file mode 100644 index 000000000..7ea12670f --- /dev/null +++ b/test/test_jra55_region.jl @@ -0,0 +1,62 @@ +include("runtests_setup.jl") + +using NumericalEarth.DataWrangling: BoundingBox, Column, Linear +using Oceananigans.Fields: interpolate as oc_interpolate +using Oceananigans.Grids: topology, Bounded + +@testset "JRA55 region support" begin + arch = CPU() + + @testset "BoundingBox slices the right window" begin + bbox = BoundingBox(longitude=(120, 240), latitude=(-30, 30)) + atm = JRA55PrescribedAtmosphere(arch; + time_indices_in_memory=2, + include_rivers_and_icebergs=false, + region=bbox) + Ta = atm.tracers.T + # Coordinates of the field grid should fall inside the requested bbox. + λnodes_T = λnodes(Ta.grid, Center()) + φnodes_T = φnodes(Ta.grid, Center()) + @test minimum(λnodes_T) ≥ 120 - 1.5 # 1.5° JRA55 spacing tolerance + @test maximum(λnodes_T) ≤ 240 + 1.5 + @test minimum(φnodes_T) ≥ -30 - 1.5 + @test maximum(φnodes_T) ≤ 30 + 1.5 + @test any(!iszero, interior(Ta)) + @test !any(isnan, interior(Ta)) + # Sub-360° span must be Bounded in x so halos do not wrap. 
+ @test topology(Ta.grid)[1] == Bounded + end + + @testset "Column extracts a single point" begin + col = Column(150.0, 0.0) # equator, central Pacific + atm = JRA55PrescribedAtmosphere(arch; + time_indices_in_memory=2, + include_rivers_and_icebergs=false, + region=col) + Ta = atm.tracers.T + @test size(Ta.grid, 1) == 1 + @test size(Ta.grid, 2) == 1 + @test any(!iszero, interior(Ta)) + @test !any(isnan, interior(Ta)) + end + + @testset "Column matches bbox bilinear at the column point" begin + # The Column dispatch should produce the same value as bilinearly + # interpolating a bbox-extracted FTS at the column's (lon, lat). + col_atm = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=2, + include_rivers_and_icebergs=false, + region=Column(150.0, 0.0)) + bbox_atm = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=2, + include_rivers_and_icebergs=false, + region=BoundingBox(longitude=(148, 152), + latitude=(-2, 2))) + + Ta_col = col_atm.tracers.T + Ta_bbox = bbox_atm.tracers.T + + T_col_t1 = interior(Ta_col)[1, 1, 1, 1] + loc = (Center(), Center(), Center()) + T_bbox_t1 = oc_interpolate((150.0, 0.0, 0.0), Ta_bbox[1], loc, Ta_bbox.grid) + @test T_col_t1 ≈ T_bbox_t1 rtol = 1e-3 + end +end diff --git a/test/test_mangling.jl b/test/test_mangling.jl new file mode 100644 index 000000000..bd2886ea4 --- /dev/null +++ b/test/test_mangling.jl @@ -0,0 +1,48 @@ +include("runtests_setup.jl") + +using NumericalEarth.DataWrangling: mangle, mangling_for, ShiftSouth, AverageNorthSouth + +using Oceananigans.Fields: location + +@testset "mangle dispatch" begin + data = reshape(Float32[1 2 3; 4 5 6; 7 8 9], 3, 3, 1) + + # Identity: reads (i, j, k) directly. + @test mangle(2, 2, 1, data, nothing) == 5 + + # ShiftSouth: reads (i, j-1, k); j=1 clamps to j=1. + @test mangle(2, 2, 1, data, ShiftSouth()) == 4 + @test mangle(2, 1, 1, data, ShiftSouth()) == 4 + + # AverageNorthSouth: averages (i, j, k) and (i, j+1, k). 
+ @test mangle(2, 1, 1, data, AverageNorthSouth()) ≈ 4.5f0 + @test mangle(2, 2, 1, data, AverageNorthSouth()) ≈ 5.5f0 +end + +@testset "mangling_for size dispatch" begin + md = Metadatum(:v_velocity; dataset=ECCO4Monthly(), date=start_date) + _, Ny, _, _ = size(md) + + @test mangling_for(md, Ny) === nothing + @test mangling_for(md, Ny - 1) isa ShiftSouth + @test mangling_for(md, Ny + 1) isa AverageNorthSouth + @test mangling_for(md, Ny - 2) === nothing + @test mangling_for(md, Ny + 2) === nothing +end + +@testset "ECCO v_velocity Field uses ShiftSouth mangling end-to-end" begin + for arch in test_architectures + md = Metadatum(:v_velocity; dataset=ECCO4Monthly(), date=start_date) + field = Field(md, arch) + + # v lives on the latitude Face for ECCO; the file ships Ny-1 lat + # entries, so mangling_for must select ShiftSouth — anything else + # would either go out of bounds or read off-by-one everywhere. + @test location(field) == (Center, Face, Center) + + @allowscalar begin + @test any(!=(0), interior(field)) + @test any(isfinite, interior(field)) + end + end +end diff --git a/test/test_metadata.jl b/test/test_metadata.jl index e55fd2fdf..8edbe356c 100644 --- a/test/test_metadata.jl +++ b/test/test_metadata.jl @@ -3,6 +3,7 @@ include("runtests_setup.jl") using NumericalEarth.DataWrangling: Column, Linear, Nearest, BoundingBox, dataset_location, restrict_location, native_grid +using NumericalEarth.DataWrangling: restrict using Oceananigans: RectilinearGrid, LatitudeLongitudeGrid, location using Oceananigans.Grids: topology, Flat, Bounded, Periodic @@ -111,6 +112,19 @@ end Nx, Ny, Nz = size(grid) @test Nx < Nx_full @test Ny < Ny_full + + # Sub-360° bbox must be Bounded in x (not Periodic) so halos don't wrap. + @test topology(grid)[1] == Bounded + + # 360°-spanning bbox keeps Periodic in x. 
+ bbox_full = BoundingBox(longitude=(-180, 180), latitude=(-30, 30)) + md_full = Metadatum(:temperature; dataset=ECCO4Monthly(), region=bbox_full) + @test topology(native_grid(md_full))[1] == Periodic + + # Latitude-only restriction: longitude is unrestricted, x stays Periodic. + bbox_lat = BoundingBox(latitude=(-30, 30)) + md_lat = Metadatum(:temperature; dataset=ECCO4Monthly(), region=bbox_lat) + @test topology(native_grid(md_lat))[1] == Periodic end @testset "Metadata region keyword" begin @@ -141,3 +155,38 @@ end @test last(md).region === col @test md[1].region === col end + +@testset "restrict() snaps to native interfaces" begin + # Uniform interfaces: snapping coincides with the user's bbox if it + # already lies on cell boundaries. + interfaces = collect(0.0:1.0:10.0) + sliced, rN = restrict((2.0, 6.0), interfaces, 10) + @test sliced == [2.0, 3.0, 4.0, 5.0, 6.0] + @test rN == 4 + + # Uniform interfaces, off-grid bbox: snap outward to the surrounding + # native cells so the result is a superset of the request. + sliced, rN = restrict((2.5, 6.5), interfaces, 10) + @test sliced[1] ≤ 2.5 + @test sliced[end] ≥ 6.5 + @test rN == length(sliced) - 1 + + # Stretched interfaces (cells get wider): snapping must return the + # actual native interfaces, not a 2-tuple of the user's bbox. + stretched = [0.0, 0.5, 1.5, 3.0, 5.5, 9.5, 15.0] + sliced, rN = restrict((1.0, 6.0), stretched, length(stretched) - 1) + @test sliced == [0.5, 1.5, 3.0, 5.5, 9.5] + @test rN == 4 + + # Out-of-range bbox is clamped, not crashed. + sliced, rN = restrict((-100.0, 100.0), stretched, length(stretched) - 1) + @test sliced == stretched + @test rN == length(stretched) - 1 + + # 2-tuple endpoints (the convention used by `longitude_interfaces` / + # `latitude_interfaces` for uniform native grids like JRA55/ECCO). 
+ sliced, rN = restrict((120, 240), (0, 360), 360) + @test minimum(sliced) ≤ 120 + @test maximum(sliced) ≥ 240 + @test rN == length(sliced) - 1 +end From 7b76d7c2cb4a037fefb0b02bb8eb6c31d388def2 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 28 Apr 2026 15:21:43 +0200 Subject: [PATCH 29/44] restore CI --- .github/workflows/ci.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 678e872fb..1171dc75a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -108,8 +108,6 @@ jobs: env: GPU_TEST: "true" OPENBLAS_NUM_THREADS: 1 - # Keep workers alive past ParallelTestRunner's default 3800 MB recycle threshold - JULIA_TEST_MAXRSS_MB: "7000" # Redirect temp to the workspace volume (bind-mounted from the # host's EBS disk) to avoid filling the container's overlay filesystem with # large data downloads and compiled artifacts. @@ -138,7 +136,7 @@ jobs: run: | # Run tests in verbose mode TEST_ARGS=(--verbose) - TEST_ARGS+=(--jobs=$(($(nproc) - 2))) + TEST_ARGS+=(--jobs=$(($(nproc) - 1))) echo "runtest_test_args=${TEST_ARGS[@]}" | tee -a "${GITHUB_ENV}" - name: Update registry shell: julia --color=yes {0} @@ -149,7 +147,7 @@ jobs: Pkg.Registry.update() - name: Run tests run: | - earlyoom -m 3 -s 100 -r 300 & + earlyoom -m 3 -s 100 -r 300 --prefer 'julia' & julia --project --color=yes --check-bounds=yes -e ' using Pkg; Pkg.test(; coverage=true, From a05d4dc2d78b82bd2fffcbfa74331c65f36d5572 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 28 Apr 2026 15:31:18 +0200 Subject: [PATCH 30/44] more fixes --- .../JRA55/JRA55_field_time_series.jl | 46 +++---------------- 1 file changed, 7 insertions(+), 39 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55_field_time_series.jl b/src/DataWrangling/JRA55/JRA55_field_time_series.jl index efeedc6b0..dfd90e635 100644 --- a/src/DataWrangling/JRA55/JRA55_field_time_series.jl +++ 
b/src/DataWrangling/JRA55/JRA55_field_time_series.jl @@ -15,14 +15,10 @@ const JRA55NetCDFFTSMultipleYears = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:Da """ retrieve_data(metadatum::JRA55Metadatum) -Read the 2D slice from the JRA55 NetCDF file corresponding to `metadatum`'s -single date. JRA55 files chunk the series by calendar year and use a -`DateTimeNoLeap` (365-day) calendar internally, so the file-local index -must be resolved against either the file's own time axis (the safe path -for `MultiYearJRA55`, which spans real leap years) or the position within -the year's `all_dates` (which is unambiguous for `RepeatYearJRA55`, -because the repeat year — 1990 — is itself non-leap and the file holds -exactly 2920 entries that align 1:1 with `all_dates`). +Read the 2D slice from the JRA55 NetCDF file corresponding to `metadatum`'s single date. +`RepeatYearJRA55` resolves the index from the position within `all_dates` (the file holds +exactly 2920 entries for 1990 and aligns 1:1 with `all_dates`); `MultiYearJRA55` resolves +the index against the file's own `time` axis (one Gregorian-calendar file per year). 
""" function retrieve_data(metadatum::RepeatYearJRA55Metadatum) path = metadata_path(metadatum) @@ -47,14 +43,11 @@ function retrieve_data(metadatum::MultiYearJRA55Metadatum) ds = Dataset(path) file_dates = ds["time"][:] - file_idx = jra55_no_leap_file_index(file_dates, metadatum.dates) + file_idx = findfirst(==(metadatum.dates), file_dates) if isnothing(file_idx) close(ds) - throw(ArgumentError(string("Date ", metadatum.dates, - " not found in JRA55 multi-year file ", path, - " (note: JRA55 multi-year files use a no-leap calendar; ", - "Feb 29 of leap years has no corresponding file entry)."))) + throw(ArgumentError(string("Date ", metadatum.dates, " not found in JRA55 multi-year file ", path, "."))) end data = ds[name][:, :, file_idx] @@ -62,29 +55,6 @@ function retrieve_data(metadatum::MultiYearJRA55Metadatum) return data end -# Find the file-time index whose calendar components (Y/M/D/H/min) match -# the target date. Calendar-component matching avoids the -# `DateTimeNoLeap` ↔ `DateTime` epoch / leap-day mismatch that would -# otherwise break naive arithmetic-based lookup. -function jra55_no_leap_file_index(file_dates, target) - return findfirst(file_dates) do d - !ismissing(d) && - Dates.year(d) == Dates.year(target) && - Dates.month(d) == Dates.month(target) && - Dates.day(d) == Dates.day(target) && - Dates.hour(d) == Dates.hour(target) && - Dates.minute(d) == Dates.minute(target) - end -end - -# Note that each file should have the variables -# - ds["time"]: time coordinate -# - ds["lon"]: longitude at the location of the variable -# - ds["lat"]: latitude at the location of the variable -# - ds["lon_bnds"]: bounding longitudes between which variables are averaged -# - ds["lat_bnds"]: bounding latitudes between which variables are averaged -# - ds[shortname]: the variable data - # Split at the wrap point if `nn` cycles past 1 — DiskArrays requires sorted indices. 
function jra55_read_data(ds, name, i, j, nn) if issorted(nn) @@ -115,8 +85,6 @@ function set!(fts::JRA55NetCDFFTSRepeatYear, backend=fts.backend) return nothing end -# JRA55 multi-year files use the no-leap calendar; matching by date components -# sidesteps the seconds-since-start drift across leap days. function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend) metadata = backend.metadata name = dataset_variable_name(metadata) @@ -132,7 +100,7 @@ function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend) nn = Int[] ftsn_loc = Int[] for (loc, slot_date) in enumerate(slot_dates) - file_idx = jra55_no_leap_file_index(file_dates, slot_date) + file_idx = findfirst(==(slot_date), file_dates) if !isnothing(file_idx) push!(nn, file_idx) push!(ftsn_loc, loc) From e60e24939c02405b5a45e0fc26a365397f08293b Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 28 Apr 2026 15:46:31 +0200 Subject: [PATCH 31/44] small bugfix --- src/DataWrangling/set_region_data.jl | 32 +++++++++++++--------------- test/test_column_field.jl | 19 +++++++++++------ 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/DataWrangling/set_region_data.jl b/src/DataWrangling/set_region_data.jl index c2bd910a4..f7bf8db6b 100644 --- a/src/DataWrangling/set_region_data.jl +++ b/src/DataWrangling/set_region_data.jl @@ -40,9 +40,8 @@ end struct ShiftSouth end struct AverageNorthSouth end -# `mangle(i, j, k, data, mangling)` reads file `data` at metadata-grid index -# `(i, j, k)`, accounting for staggered lat-axis offsets. Used inside the -# region-aware kernel. +# `mangle(i, j, k, data, mangling)` reads file `data` at metadata-grid index `(i, j, k)`, accounting +# for staggered lat-axis offsets. Used inside the region-aware kernel. 
@inline mangle(i, j, k, data, ::Nothing) = @inbounds data[i, j, k] @inline mangle(i, j, k, data, ::ShiftSouth) = @inbounds data[i, max(j - 1, 1), k] @inline mangle(i, j, k, data, ::AverageNorthSouth) = @inbounds (data[i, j, k] + data[i, j + 1, k]) / 2 @@ -95,12 +94,14 @@ function infer_longitudinal_period(λc) return span ≈ 360 ? 360 : nothing end -# Cyclic-aware bracketing. With `period`, the cell between `coords[end]` and -# `coords[1] + period` is the wrap cell: returns `(n, 1, w)` so the blend -# reads `data[n, …]` and `data[1, …]`. +# Cyclic-aware bracketing. With `period`, the cell between `coords[end]` and `coords[1] + period` is the wrap cell: +# returns `(n, 1, w)` so the blend reads `data[n, …]` and `data[1, …]`. function bracket_with_weight(coords, x; period = nothing) n = length(coords) - + + # Single-cell axis: nothing to bracket — point both corners at the only cell. + n ≤ 1 && return 1, 1, zero(x) + if !isnothing(period) x = coords[1] + mod(x - coords[1], period) if x > coords[end] @@ -127,15 +128,13 @@ function mangling_for(metadata, data_lat_count) end # `read_data(data, i, j, k, region, mangling, FT)` returns the file value at -# the grid's (i, j, k) as `FT`, with `Missing` already converted to NaN. +# the grid's (i, j, k) as `FT`, with `Missing` converted to NaN. @inline read_data(data, i, j, k, ::Nothing, mangling, FT) = nan_convert_missing(FT, mangle(i, j, k, data, mangling)) @inline read_data(data, i, j, k, b::BBoxOffset, mangling, FT) = nan_convert_missing(FT, mangle(i + b.di, j + b.dj, k, data, mangling)) @inline read_data(data, _, _, k, c::ColumnInfo, mangling, FT) = blend(data, c, k, mangling, c.ℑ, FT) -# NaN-aware bilinear blend: drop NaN corners and renormalise weights. Each -# corner is `nan_convert_missing`-coerced to `FT` before mixing, so `Missing` -# values from NetCDF arrays don't poison the arithmetic. Returns NaN only -# when all four corners are land. +# NaN-aware bilinear blend: drop NaN corners and renormalise weights. 
+# Returns NaN only when all four corners are land. @inline function blend(data, c, k, mangling, ::Linear, FT) d00 = nan_convert_missing(FT, mangle(c.i⁻, c.j⁻, k, data, mangling)) d10 = nan_convert_missing(FT, mangle(c.i⁺, c.j⁻, k, data, mangling)) @@ -147,11 +146,10 @@ end w11 = c.wx * c.wy * !isnan(d11) Σw = w00 + w10 + w01 + w11 Σw == 0 && return convert(FT, NaN) - z = zero(FT) - return (w00 * ifelse(isnan(d00), z, d00) + - w10 * ifelse(isnan(d10), z, d10) + - w01 * ifelse(isnan(d01), z, d01) + - w11 * ifelse(isnan(d11), z, d11)) / Σw + return (w00 * ifelse(isnan(d00), zero(FT), d00) + + w10 * ifelse(isnan(d10), zero(FT), d10) + + w01 * ifelse(isnan(d01), zero(FT), d01) + + w11 * ifelse(isnan(d11), zero(FT), d11)) / Σw end @inline function blend(data, c, k, mangling, ::Nearest, FT) diff --git a/test/test_column_field.jl b/test/test_column_field.jl index 16ea52f50..b8814bf66 100644 --- a/test/test_column_field.jl +++ b/test/test_column_field.jl @@ -3,7 +3,7 @@ include("runtests_setup.jl") using NumericalEarth.DataWrangling: Column, Linear, Nearest, BoundingBox, native_grid, restrict_location, dataset_location -using NumericalEarth.DataWrangling: bracket_with_weight, infer_lon_period, +using NumericalEarth.DataWrangling: bracket_with_weight, infer_longitudinal_period, region_info, blend, ColumnInfo using Oceananigans @@ -37,6 +37,13 @@ const test_latitude = -50.0 i⁻, i⁺, w = bracket_with_weight(coords, 3.5) @test (i⁻, i⁺) == (3, 4) @test w ≈ 1.0 + + # Single-cell axis: nothing to bracket; both corners point at the only cell. + # GLORYS via CopernicusMarine returns 1-cell-wide chunked files for Column queries. 
+ i⁻, i⁺, w = bracket_with_weight([7.5], 7.5) + @test (i⁻, i⁺, w) == (1, 1, 0.0) + i⁻, i⁺, w = bracket_with_weight([7.5], 99.0) + @test (i⁻, i⁺, w) == (1, 1, 0.0) end @testset "bracket_with_weight (cyclic wrap)" begin @@ -59,11 +66,11 @@ end @test i⁺ == 2 || (i⁻ == n && i⁺ == 1) # right at coords[1] boundary end -@testset "infer_lon_period" begin - @test infer_lon_period(collect(0.5:1.0:359.5)) == 360 - @test infer_lon_period(collect(-179.75:0.5:179.75)) == 360 - @test infer_lon_period([10.0, 11.0, 12.0]) === nothing - @test infer_lon_period([100.0]) === nothing +@testset "infer_longitudinal_period" begin + @test infer_longitudinal_period(collect(0.5:1.0:359.5)) == 360 + @test infer_longitudinal_period(collect(-179.75:0.5:179.75)) == 360 + @test infer_longitudinal_period([10.0, 11.0, 12.0]) === nothing + @test infer_longitudinal_period([100.0]) === nothing end @testset "NaN-aware blend" begin From 52ef0f4979972e9185f1e4e69d9701014b3068f8 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 28 Apr 2026 16:19:30 +0200 Subject: [PATCH 32/44] correct WOA climatology --- src/DataWrangling/WOA/WOA.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/DataWrangling/WOA/WOA.jl b/src/DataWrangling/WOA/WOA.jl index 013b4371e..4fa709804 100644 --- a/src/DataWrangling/WOA/WOA.jl +++ b/src/DataWrangling/WOA/WOA.jl @@ -33,6 +33,8 @@ import NumericalEarth.DataWrangling: z_interfaces, longitude_interfaces, latitude_interfaces, + longitude_name, + latitude_name, is_three_dimensional, reversed_vertical_axis, inpainted_metadata_path, @@ -89,6 +91,8 @@ reversed_vertical_axis(::WOAClimatology) = true longitude_interfaces(::WOAClimatology) = (-180, 180) latitude_interfaces(::WOAClimatology) = (-90, 90) +longitude_name(::Metadata{<:WOAClimatology}) = "lon" +latitude_name(::Metadata{<:WOAClimatology}) = "lat" available_variables(::WOAClimatology) = WOA_variable_names """ From 7058a002a93abf2694e91f697594c14f23e77854 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 
28 Apr 2026 16:46:00 +0200 Subject: [PATCH 33/44] fix uniform cells `restrict` --- src/DataWrangling/metadata_field.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/DataWrangling/metadata_field.jl b/src/DataWrangling/metadata_field.jl index 73eafd70f..3eb844da4 100644 --- a/src/DataWrangling/metadata_field.jl +++ b/src/DataWrangling/metadata_field.jl @@ -25,10 +25,12 @@ restrict(::Nothing, interfaces, N) = interfaces, N restrict(::Nothing, interfaces::NTuple{2,Any}, N) = interfaces, N restrict(::Nothing, interfaces::AbstractVector, N) = interfaces, N -# Uniform native grid: 2-tuple endpoints expand to a uniform interface range. +# Uniform native grid: keep the bbox endpoints verbatim with a proportional cell count. function restrict(bbox_interfaces, interfaces::NTuple{2,Any}, N) - full = range(interfaces[1], interfaces[2]; length = N + 1) - return restrict(bbox_interfaces, full, N) + extent = interfaces[2] - interfaces[1] + rΔ = bbox_interfaces[2] - bbox_interfaces[1] + rN = max(round(Int, rΔ / extent * N), 1) + return bbox_interfaces, rN end # Stretched native grid: snap outward to the nearest native cell interfaces. 
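[Editorial aside between patches 33 and 34.] The uniform-grid `restrict` introduced in PATCH 33 above can be exercised standalone. The sketch below mirrors the `+` lines of that hunk (renamed `restrict_uniform` here so it doesn't clash with the package's own `restrict` methods); note that a later patch in this series (PATCH 42) refines this method to snap outward to native cell faces instead of keeping the bbox endpoints verbatim.

```julia
# Sketch of PATCH 33's uniform-grid restrict: `interfaces` is a 2-tuple of
# domain endpoints for a uniform native grid with N cells. The bbox endpoints
# are kept verbatim, and the cell count is scaled proportionally to the bbox
# extent, clamped to at least 1 so a sub-cell bbox never yields an empty grid.
function restrict_uniform(bbox_interfaces, interfaces, N)
    extent = interfaces[2] - interfaces[1]
    rΔ = bbox_interfaces[2] - bbox_interfaces[1]
    rN = max(round(Int, rΔ / extent * N), 1)
    return bbox_interfaces, rN
end

restrict_uniform((120, 240), (0, 360), 360)     # → ((120, 240), 120)
restrict_uniform((0.0, 0.01), (0.0, 360.0), 1440)  # → ((0.0, 0.01), 1)
```

This matches the `@test sliced == (120, 240)` / `@test rN == 120` expectations added to `test/test_metadata.jl` in PATCH 35.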
From d3454dd00b2cb5cea21269d8188a7953cde17434 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 28 Apr 2026 17:10:02 +0200 Subject: [PATCH 34/44] fix longitude and latitude names --- src/DataWrangling/ECCO/ECCO.jl | 9 +++++++++ src/DataWrangling/EN4/EN4.jl | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/src/DataWrangling/ECCO/ECCO.jl b/src/DataWrangling/ECCO/ECCO.jl index e30d929ec..248f3b808 100644 --- a/src/DataWrangling/ECCO/ECCO.jl +++ b/src/DataWrangling/ECCO/ECCO.jl @@ -50,6 +50,8 @@ import NumericalEarth.DataWrangling: metaprefix, longitude_interfaces, latitude_interfaces, + longitude_name, + latitude_name, z_interfaces, is_three_dimensional, inpainted_metadata_path, @@ -112,6 +114,13 @@ longitude_interfaces(::ECCODataset) = (0, 360) longitude_interfaces(::ECCO4Monthly) = (-180, 180) latitude_interfaces(::ECCODataset) = (-90, 90) +longitude_name(::Metadata{<:ECCODataset}) = "LONGITUDE_T" +latitude_name(::Metadata{<:ECCODataset}) = "LATITUDE_T" +longitude_name(::Metadata{<:ECCO4Monthly}) = "longitude" +latitude_name(::Metadata{<:ECCO4Monthly}) = "latitude" +longitude_name(::Metadata{<:ECCO4DarwinMonthly}) = "longitude" +latitude_name(::Metadata{<:ECCO4DarwinMonthly}) = "latitude" + z_interfaces(::ECCODataset) = [ -6128.75, -5683.75, diff --git a/src/DataWrangling/EN4/EN4.jl b/src/DataWrangling/EN4/EN4.jl index 851401439..8e98b1adf 100644 --- a/src/DataWrangling/EN4/EN4.jl +++ b/src/DataWrangling/EN4/EN4.jl @@ -45,6 +45,8 @@ import NumericalEarth.DataWrangling: z_interfaces, longitude_interfaces, latitude_interfaces, + longitude_name, + latitude_name, is_three_dimensional, reversed_vertical_axis, inpainted_metadata_path, @@ -70,6 +72,8 @@ reversed_vertical_axis(::EN4Monthly) = true longitude_interfaces(::EN4Monthly) = (0.5, 360.5) latitude_interfaces(::EN4Monthly) = (-83.5, 89.5) +longitude_name(::Metadata{<:EN4Monthly}) = "lon" +latitude_name(::Metadata{<:EN4Monthly}) = "lat" available_variables(::EN4Monthly) = EN4_dataset_variable_names 
z_interfaces(::EN4Monthly) = [ From fd577ca2917bfd3a213eb8c4360a6a228ab43c6b Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 28 Apr 2026 19:51:43 +0200 Subject: [PATCH 35/44] clamp region offsets and shift bbox longitudes into file range --- src/DataWrangling/set_region_data.jl | 25 +++++++++++++++++++++---- test/test_metadata.jl | 10 ++++------ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/DataWrangling/set_region_data.jl b/src/DataWrangling/set_region_data.jl index f7bf8db6b..73e401360 100644 --- a/src/DataWrangling/set_region_data.jl +++ b/src/DataWrangling/set_region_data.jl @@ -73,10 +73,27 @@ end region_info(::Nothing, target, λc, φc) = nothing function region_info(::BoundingBox, target, λc, φc) - LX, LY, _ = Oceananigans.Fields.location(target) - i₁, _ = compute_bounding_indices(compute_bounding_nodes(target.grid, LX, λnodes), λc) - j₁, _ = compute_bounding_indices(compute_bounding_nodes(target.grid, LY, φnodes), φc) - return BBoxOffset(i₁ - 1, j₁ - 1) + LX, LY, _ = Oceananigans.Fields.location(target) + λmin, λmax = compute_bounding_nodes(target.grid, LX, λnodes) + φmin, φmax = compute_bounding_nodes(target.grid, LY, φnodes) + λmin, λmax = shift_into_range(λmin, λmax, λc) + + i₁, _ = compute_bounding_indices((λmin, λmax), λc) + j₁, _ = compute_bounding_indices((φmin, φmax), φc) + + Nx, Ny, _ = size(target) + di = clamp(i₁ - 1, 0, max(length(λc) - Nx, 0)) + dj = clamp(j₁ - 1, 0, max(length(φc) - Ny, 0)) + return BBoxOffset(di, dj) +end + +# Shift `(a, b)` by ±360° so it falls inside `λc`'s range.
+function shift_into_range(a, b, λc) + isempty(λc) && return a, b + lo, hi = λc[1], λc[end] + a > hi && b - 360 ≥ lo && return a - 360, b - 360 + b < lo && a + 360 ≤ hi && return a + 360, b + 360 + return a, b end function region_info(col::Column, target, λc, φc) diff --git a/test/test_metadata.jl b/test/test_metadata.jl index 8edbe356c..247c48dd5 100644 --- a/test/test_metadata.jl +++ b/test/test_metadata.jl @@ -183,10 +183,8 @@ end @test sliced == stretched @test rN == length(stretched) - 1 - # 2-tuple endpoints (the convention used by `longitude_interfaces` / - # `latitude_interfaces` for uniform native grids like JRA55/ECCO). - sliced, rN = restrict((120, 240), (0, 360), 360) - @test minimum(sliced) ≤ 120 - @test maximum(sliced) ≥ 240 - @test rN == length(sliced) - 1 + # 2-tuple endpoints: uniform native grids return the bbox endpoints + # verbatim (no snap) with a proportional cell count. + @test sliced == (120, 240) + @test rN == 120 end From 4613e681f6b772c8ad635e6dbfcce7d9544624f0 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 28 Apr 2026 21:30:52 +0200 Subject: [PATCH 36/44] use Oceananigans' method --- src/DataWrangling/set_region_data.jl | 17 +++++++---------- test/test_metadata.jl | 5 ++++- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/DataWrangling/set_region_data.jl b/src/DataWrangling/set_region_data.jl index 73e401360..4eceb98bb 100644 --- a/src/DataWrangling/set_region_data.jl +++ b/src/DataWrangling/set_region_data.jl @@ -2,6 +2,7 @@ using Oceananigans.Utils: launch! using Oceananigans.Architectures: AbstractArchitecture, architecture using Oceananigans.Grids: AbstractGrid, Periodic, Bounded, λnodes, φnodes using Oceananigans.Fields: Field, interior, interpolate! 
+using Oceananigans.Fields: convert_to_λ₀_λ₀_plus360 using GPUArraysCore: @allowscalar ##### @@ -76,7 +77,12 @@ function region_info(::BoundingBox, target, λc, φc) LX, LY, _ = Oceananigans.Fields.location(target) λmin, λmax = compute_bounding_nodes(target.grid, LX, λnodes) φmin, φmax = compute_bounding_nodes(target.grid, LY, φnodes) - λmin, λmax = shift_into_range(λmin, λmax, λc) + + # Shift the target's longitude into the file's `[λc[1], λc[1]+360)` + if !isempty(λc) + λmin = convert_to_λ₀_λ₀_plus360(λmin, λc[1]) + λmax = convert_to_λ₀_λ₀_plus360(λmax, λc[1]) + end i₁, _ = compute_bounding_indices((λmin, λmax), λc) j₁, _ = compute_bounding_indices((φmin, φmax), φc) @@ -87,15 +93,6 @@ function region_info(::BoundingBox, target, λc, φc) return BBoxOffset(di, dj) end -# Shift `(a, b)` by ±360° so it falls inside `λc`'s range. -function shift_into_range(a, b, λc) - isempty(λc) && return a, b - lo, hi = λc[1], λc[end] - a > hi && b - 360 ≥ lo && return a - 360, b - 360 - b < lo && a + 360 ≤ hi && return a + 360, b + 360 - return a, b -end - function region_info(col::Column, target, λc, φc) i⁻, i⁺, wx = bracket_with_weight(λc, col.longitude; period = infer_longitudinal_period(λc)) j⁻, j⁺, wy = bracket_with_weight(φc, col.latitude) # latitude is never cyclic diff --git a/test/test_metadata.jl b/test/test_metadata.jl index 247c48dd5..389794cd9 100644 --- a/test/test_metadata.jl +++ b/test/test_metadata.jl @@ -184,7 +184,10 @@ end @test rN == length(stretched) - 1 # 2-tuple endpoints: uniform native grids return the bbox endpoints - # verbatim (no snap) with a proportional cell count. + # verbatim (no snap) with a proportional cell count. Stays correct across + # longitude conventions for pre-subsetted files (e.g. GLORYS via Copernicus); + # the centre alignment is handled at read time by `region_info`. 
+ sliced, rN = restrict((120, 240), (0, 360), 360) @test sliced == (120, 240) @test rN == 120 end From 85681ae09e04d6c02d77a2575e88b578c4d61383 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Wed, 29 Apr 2026 09:48:42 +0200 Subject: [PATCH 37/44] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e33fb9936..307c1f3a8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NumericalEarth" uuid = "904d977b-046a-4731-8b86-9235c0d1ef02" license = "MIT" -version = "0.3.0" +version = "0.3.1" authors = ["NumericalEarth contributors"] [deps] From 3fa6f53b6241199771f447e156c1e17e6aca277f Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Wed, 29 Apr 2026 18:00:40 +0200 Subject: [PATCH 38/44] remove extra comment --- src/DataWrangling/JRA55/JRA55_prescribed_land.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/DataWrangling/JRA55/JRA55_prescribed_land.jl b/src/DataWrangling/JRA55/JRA55_prescribed_land.jl index 28aa6d2ea..f20672adb 100644 --- a/src/DataWrangling/JRA55/JRA55_prescribed_land.jl +++ b/src/DataWrangling/JRA55/JRA55_prescribed_land.jl @@ -17,11 +17,7 @@ JRA55PrescribedLand(arch::Distributed; kw...) = other_kw...) Return a [`PrescribedLand`](@ref) representing JRA55 reanalysis land surface data -(river runoff and iceberg calving freshwater fluxes). Each freshwater field is -constructed via `FieldTimeSeries(::JRA55Metadata)`, which uses a `DatasetBackend` -parameterised by JRA55 metadata so that the JRA55-specific `set!` (chunked-yearly -NetCDF) is dispatched. The `region` keyword restricts the land fields to a sub-domain -of the global JRA55 grid. +(river runoff and iceberg calving freshwater fluxes). 
""" function JRA55PrescribedLand(architecture = CPU(); dataset = RepeatYearJRA55(), From 958cffa34f5450b1c25a5159cc715a4c8f63ad18 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Wed, 29 Apr 2026 19:00:27 +0200 Subject: [PATCH 39/44] fix test surface fluxes --- test/test_surface_fluxes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_surface_fluxes.jl b/test/test_surface_fluxes.jl index bf1b4d687..632fbbdd6 100644 --- a/test/test_surface_fluxes.jl +++ b/test/test_surface_fluxes.jl @@ -196,7 +196,7 @@ end set!(ocean_with_land.model, T = 15, S = 30) land_dates = all_dates(RepeatYearJRA55(), :river_freshwater_flux) - land = JRA55PrescribedLand(arch; end_date=land_dates[2], backend = InMemory()) + land = JRA55PrescribedLand(arch; end_date=land_dates[2]) model_with_land = OceanOnlyModel(ocean_with_land; atmosphere, land) # Verify land exchanger is wired up From f69483b6c4e0f0aa6975595fe06eb5243deb5fb0 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Wed, 29 Apr 2026 20:50:46 +0200 Subject: [PATCH 40/44] fix the tests --- src/DataWrangling/metadata_field.jl | 31 ++++++++++++++--------------- test/test_ocean_only_model.jl | 3 ++- test/test_ocean_sea_ice_model.jl | 3 ++- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/DataWrangling/metadata_field.jl b/src/DataWrangling/metadata_field.jl index 3eb844da4..1d178a4d5 100644 --- a/src/DataWrangling/metadata_field.jl +++ b/src/DataWrangling/metadata_field.jl @@ -180,24 +180,23 @@ function Field(metadata::Metadatum, arch=CPU(); if !isnothing(inpainting) inpainted_path = inpainted_metadata_path(metadata) if isfile(inpainted_path) - file = jldopen(inpainted_path, "r") - maxiter = file["inpainting_maxiter"] - - # read data if generated with the same inpainting - if maxiter == inpainting.maxiter - data = file["data"] - close(file) - try - copyto!(parent(field), data) - return field - catch - @warn "Could not load existing inpainted data at $inpainted_path.\n" * - 
"Re-inpainting and saving data..." - rm(inpainted_path, force=true) + # apply a load guard for corrupted files + loaded = false + try + jldopen(inpainted_path, "r") do file + if haskey(file, "inpainting_maxiter") && + file["inpainting_maxiter"] == inpainting.maxiter + copyto!(parent(field), file["data"]) + loaded = true + end end + catch err + @warn "Could not load existing inpainted data at $inpainted_path; " * + "re-inpainting and saving data..." exception=err + rm(inpainted_path, force=true) + loaded = false end - - close(file) + loaded && return field end end diff --git a/test/test_ocean_only_model.jl b/test/test_ocean_only_model.jl index a7777057e..35f2f7211 100644 --- a/test/test_ocean_only_model.jl +++ b/test/test_ocean_only_model.jl @@ -62,7 +62,8 @@ using Oceananigans.OrthogonalSphericalShellGrids ##### @info "Testing OceanOnlyModel with JRA55PrescribedLand on $A..." - land = JRA55PrescribedLand(arch; backend) + land_dates = all_dates(RepeatYearJRA55(), :river_freshwater_flux) + land = JRA55PrescribedLand(arch; end_date=land_dates[2]) @test begin ocean_with_land = ocean_simulation(grid; free_surface) diff --git a/test/test_ocean_sea_ice_model.jl b/test/test_ocean_sea_ice_model.jl index 167df37c7..268ba339f 100644 --- a/test/test_ocean_sea_ice_model.jl +++ b/test/test_ocean_sea_ice_model.jl @@ -70,7 +70,8 @@ using ClimaSeaIce.Rheologies # Test with land component @info "Testing OceanSeaIceModel with land on $A..." 
- land = JRA55PrescribedLand(arch; backend) + land_dates = all_dates(RepeatYearJRA55(), :river_freshwater_flux) + land = JRA55PrescribedLand(arch; end_date=land_dates[2]) @test begin ocean_with_land = ocean_simulation(grid; free_surface) From 521eb0a64007f5e1d9ba9f7eb2b5247101b64a0b Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Fri, 1 May 2026 10:09:49 +0200 Subject: [PATCH 41/44] fix tests --- test/test_column_field.jl | 64 +++++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/test/test_column_field.jl b/test/test_column_field.jl index ac4932b64..a30c2e876 100644 --- a/test/test_column_field.jl +++ b/test/test_column_field.jl @@ -199,8 +199,9 @@ end grid = native_grid(md) @test grid isa RectilinearGrid - @test topology(grid) == (Flat, Flat, Bounded) - # ERA5 has z = (0, 1), single level + # ERA5HourlySingleLevel is a 2D surface dataset — Column grids collapse + # the (absent) vertical axis to Flat as well. + @test topology(grid) == (Flat, Flat, Flat) @test size(grid) == (1, 1, 1) end @@ -217,32 +218,57 @@ end @testset "restrict (BoundingBox grid construction helper)" begin restrict = NumericalEarth.DataWrangling.restrict - # Identity case: bbox covers the full domain. Grid pads by Δ/2 on each side - # so that face midpoints land on data centers. The padded extent is N+1 - # cells of width Δ exactly, but Float64 rounding can push the ceil one cell - # past that — so allow [N+1, N+2]. + # Uniform native grid (interfaces given as a 2-tuple of endpoints): + # `restrict` keeps the bbox endpoints verbatim and assigns a cell count + # proportional to the bbox extent. + + # Identity case: bbox covers the full domain → endpoints unchanged, all cells kept. 
grid_interfaces, rN = restrict((0.0, 360.0), (0.0, 360.0), 1440) - @test grid_interfaces[1] ≈ -0.125 - @test grid_interfaces[2] ≈ 360.125 - @test 1441 <= rN <= 1442 + @test grid_interfaces == (0.0, 360.0) + @test rN == 1440 - # Half-domain bbox: rN should be just over half of N. - _, rN = restrict((0.0, 180.0), (0.0, 360.0), 1440) - @test 720 < rN <= 722 # ceil(0.5 * 1440 + small) = 721 + # Half-domain bbox: round(0.5 * 1440) = 720. + grid_interfaces, rN = restrict((0.0, 180.0), (0.0, 360.0), 1440) + @test grid_interfaces == (0.0, 180.0) + @test rN == 720 - # Small bbox (5° wide on a 1440-cell grid): rN should be ceil(20 + small) = 21. + # Small bbox (5° wide on a 1440-cell, 360°-tall grid): round(5/360 * 1440) = 20. grid_interfaces, rN = restrict((0.0, 5.0), (0.0, 360.0), 1440) - @test grid_interfaces[1] ≈ -0.125 - @test grid_interfaces[2] ≈ 5.125 - @test rN == 21 + @test grid_interfaces == (0.0, 5.0) + @test rN == 20 # Off-origin bbox preserves width: 5° wide on a 720-cell, 180°-tall grid → - # rΔ = 5° + Δ = 5.25°, rN = ceil((5.25/180) * 720) = 21. - _, rN_off = restrict((40.0, 45.0), (-90.0, 90.0), 720) - @test rN_off == 21 + # round(5/180 * 720) = 20. + grid_interfaces, rN_off = restrict((40.0, 45.0), (-90.0, 90.0), 720) + @test grid_interfaces == (40.0, 45.0) + @test rN_off == 20 + + # Sub-cell bbox: cell count is clamped to a minimum of 1 (never 0). + _, rN_tiny = restrict((0.0, 0.01), (0.0, 360.0), 1440) + @test rN_tiny == 1 + + # Stretched native grid (interfaces given as a Vector): `restrict` snaps + # outward to the nearest native cell interfaces and returns the slice. + interfaces = collect(0.0:1.0:10.0) # 10 cells, faces at 0, 1, …, 10 + + # Bbox aligned with native faces — slice exact. + grid_interfaces, rN = restrict((2.0, 5.0), interfaces, 10) + @test grid_interfaces == [2.0, 3.0, 4.0, 5.0] + @test rN == 3 + + # Bbox between native faces — snaps outward (encloses the bbox). 
+ grid_interfaces, rN = restrict((2.3, 4.7), interfaces, 10) + @test grid_interfaces == [2.0, 3.0, 4.0, 5.0] + @test rN == 3 + + # Bbox at the very start — does not underflow past index 1. + grid_interfaces, rN = restrict((-1.0, 1.5), interfaces, 10) + @test first(grid_interfaces) == 0.0 + @test rN >= 1 # Pass-through for `nothing` (the no-restriction case). @test restrict(nothing, (0.0, 360.0), 1440) == ((0.0, 360.0), 1440) + @test restrict(nothing, interfaces, 10) == (interfaces, 10) end @testset "restrict_location dispatch" begin From 8c6d701a256ab049c25f396cc57837776d025c51 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Fri, 1 May 2026 10:57:57 +0200 Subject: [PATCH 42/44] fix restrict to snap out --- src/DataWrangling/metadata_field.jl | 15 +++++--- test/test_column_field.jl | 54 ++++++++++++----------------- 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/src/DataWrangling/metadata_field.jl b/src/DataWrangling/metadata_field.jl index be1b920ab..5d7c9f0ec 100644 --- a/src/DataWrangling/metadata_field.jl +++ b/src/DataWrangling/metadata_field.jl @@ -25,12 +25,17 @@ restrict(::Nothing, interfaces, N) = interfaces, N restrict(::Nothing, interfaces::NTuple{2,Any}, N) = interfaces, N restrict(::Nothing, interfaces::AbstractVector, N) = interfaces, N -# Uniform native grid: keep the bbox endpoints verbatim with a proportional cell count. 
+# Snap bbox outward to native cell faces so restricted centers land on native centers function restrict(bbox_interfaces, interfaces::NTuple{2,Any}, N) - extent = interfaces[2] - interfaces[1] - rΔ = bbox_interfaces[2] - bbox_interfaces[1] - rN = max(round(Int, rΔ / extent * N), 1) - return bbox_interfaces, rN + left, right = interfaces + Δ = (right - left) / N + i⁻ = max(floor(Int, (bbox_interfaces[1] - left) / Δ), 0) + i⁺ = min(ceil( Int, (bbox_interfaces[2] - left) / Δ), N) + if i⁺ <= i⁻ + i⁺ = min(i⁻ + 1, N) + i⁻ = max(i⁺ - 1, 0) + end + return (left + i⁻ * Δ, left + i⁺ * Δ), i⁺ - i⁻ end # Stretched native grid: snap outward to the nearest native cell interfaces. diff --git a/test/test_column_field.jl b/test/test_column_field.jl index a30c2e876..27e99507d 100644 --- a/test/test_column_field.jl +++ b/test/test_column_field.jl @@ -218,55 +218,47 @@ end @testset "restrict (BoundingBox grid construction helper)" begin restrict = NumericalEarth.DataWrangling.restrict - # Uniform native grid (interfaces given as a 2-tuple of endpoints): - # `restrict` keeps the bbox endpoints verbatim and assigns a cell count - # proportional to the bbox extent. + # Uniform: snap bbox outward to native faces (Δ = 0.25 throughout). + grid_interfaces, rN = restrict((0.0, 5.0), (-0.125, 359.875), 1440) + @test grid_interfaces == (-0.125, 5.125) + @test rN == 21 - # Identity case: bbox covers the full domain → endpoints unchanged, all cells kept. - grid_interfaces, rN = restrict((0.0, 360.0), (0.0, 360.0), 1440) - @test grid_interfaces == (0.0, 360.0) - @test rN == 1440 + # Already face-aligned — snap is a no-op. + grid_interfaces, rN = restrict((-0.125, 5.125), (-0.125, 359.875), 1440) + @test grid_interfaces == (-0.125, 5.125) + @test rN == 21 - # Half-domain bbox: round(0.5 * 1440) = 720. 
- grid_interfaces, rN = restrict((0.0, 180.0), (0.0, 360.0), 1440) - @test grid_interfaces == (0.0, 180.0) - @test rN == 720 - - # Small bbox (5° wide on a 1440-cell, 360°-tall grid): round(5/360 * 1440) = 20. - grid_interfaces, rN = restrict((0.0, 5.0), (0.0, 360.0), 1440) - @test grid_interfaces == (0.0, 5.0) - @test rN == 20 - - # Off-origin bbox preserves width: 5° wide on a 720-cell, 180°-tall grid → - # round(5/180 * 720) = 20. grid_interfaces, rN_off = restrict((40.0, 45.0), (-90.0, 90.0), 720) @test grid_interfaces == (40.0, 45.0) @test rN_off == 20 - # Sub-cell bbox: cell count is clamped to a minimum of 1 (never 0). - _, rN_tiny = restrict((0.0, 0.01), (0.0, 360.0), 1440) - @test rN_tiny == 1 + grid_interfaces, rN = restrict((0.0, 360.0), (0.0, 360.0), 1440) + @test grid_interfaces == (0.0, 360.0) + @test rN == 1440 + + # Sub-cell bbox still spans ≥ 1 native cell. + grid_interfaces, rN_tiny = restrict((0.1, 0.15), (-0.125, 359.875), 1440) + @test grid_interfaces == (-0.125, 0.375) + @test rN_tiny == 2 - # Stretched native grid (interfaces given as a Vector): `restrict` snaps - # outward to the nearest native cell interfaces and returns the slice. - interfaces = collect(0.0:1.0:10.0) # 10 cells, faces at 0, 1, …, 10 + # Past-bounds bbox is clamped to native faces. + grid_interfaces, rN = restrict((-1.0, 1.5), (-0.125, 359.875), 1440) + @test grid_interfaces[1] == -0.125 + @test grid_interfaces[2] == 1.625 + @test rN == 7 - # Bbox aligned with native faces — slice exact. + # Stretched (Vector) interfaces. + interfaces = collect(0.0:1.0:10.0) grid_interfaces, rN = restrict((2.0, 5.0), interfaces, 10) @test grid_interfaces == [2.0, 3.0, 4.0, 5.0] @test rN == 3 - - # Bbox between native faces — snaps outward (encloses the bbox). grid_interfaces, rN = restrict((2.3, 4.7), interfaces, 10) @test grid_interfaces == [2.0, 3.0, 4.0, 5.0] @test rN == 3 - - # Bbox at the very start — does not underflow past index 1. 
grid_interfaces, rN = restrict((-1.0, 1.5), interfaces, 10) @test first(grid_interfaces) == 0.0 @test rN >= 1 - # Pass-through for `nothing` (the no-restriction case). @test restrict(nothing, (0.0, 360.0), 1440) == ((0.0, 360.0), 1440) @test restrict(nothing, interfaces, 10) == (interfaces, 10) end From 568a9f44c3e9492adb29f3c6ab78dd9bd94ed032 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Fri, 1 May 2026 16:26:06 +0200 Subject: [PATCH 43/44] fix the BBox longitudes into native longitudes --- src/DataWrangling/metadata_field.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/DataWrangling/metadata_field.jl b/src/DataWrangling/metadata_field.jl index 5d7c9f0ec..c208490a2 100644 --- a/src/DataWrangling/metadata_field.jl +++ b/src/DataWrangling/metadata_field.jl @@ -79,9 +79,13 @@ function construct_native_grid(metadata, bbox::BoundingBox, arch; halo) native_longitude = longitude_interfaces(metadata) native_latitude = latitude_interfaces(metadata) + # Map the bbox into the native longitude convention + bbox_λ⁻ = convert_to_λ₀_λ₀_plus360(bbox.longitude[1], native_longitude[1]) + bbox_λ⁺ = bbox_λ⁻ + (bbox.longitude[2] - bbox.longitude[1]) + Nx, Ny, Nz = size(metadata) - longitude, Nx = restrict(bbox.longitude, native_longitude, Nx) - latitude, Ny = restrict(bbox.latitude, native_latitude, Ny) + longitude, Nx = restrict((bbox_λ⁻, bbox_λ⁺), native_longitude, Nx) + latitude, Ny = restrict(bbox.latitude, native_latitude, Ny) TX = infer_longitudinal_topology(native_longitude, longitude) From 5d1e95da5656c53e9cd9b216492fe375fa5c85a3 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Fri, 1 May 2026 17:19:00 +0200 Subject: [PATCH 44/44] fix nothing longitudes --- src/DataWrangling/metadata_field.jl | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/DataWrangling/metadata_field.jl b/src/DataWrangling/metadata_field.jl index c208490a2..004095d3a 100644 --- a/src/DataWrangling/metadata_field.jl +++ 
b/src/DataWrangling/metadata_field.jl @@ -46,6 +46,14 @@ function restrict(bbox_interfaces, interfaces::AbstractVector, N) return interfaces[i⁻:i⁺], rN end +native_convention_longitude(::Nothing, native) = nothing + +# Map a bbox longitude into the native longitude convention +function native_convention_longitude(bbox_longitude, native) + λ⁻ = convert_to_λ₀_λ₀_plus360(bbox_longitude[1], native[1]) + return (λ⁻, λ⁻ + (bbox_longitude[2] - bbox_longitude[1])) +end + """ native_grid(metadata::Metadata, arch=CPU(); halo = (3, 3, 3)) @@ -79,13 +87,12 @@ function construct_native_grid(metadata, bbox::BoundingBox, arch; halo) native_longitude = longitude_interfaces(metadata) native_latitude = latitude_interfaces(metadata) - # Map the bbox into the native longitude convention - bbox_λ⁻ = convert_to_λ₀_λ₀_plus360(bbox.longitude[1], native_longitude[1]) - bbox_λ⁺ = bbox_λ⁻ + (bbox.longitude[2] - bbox.longitude[1]) + # Map the bbox into the native longitude convention. + bbox_lon = native_convention_longitude(bbox.longitude, native_longitude) Nx, Ny, Nz = size(metadata) - longitude, Nx = restrict((bbox_λ⁻, bbox_λ⁺), native_longitude, Nx) - latitude, Ny = restrict(bbox.latitude, native_latitude, Ny) + longitude, Nx = restrict(bbox_lon, native_longitude, Nx) + latitude, Ny = restrict(bbox.latitude, native_latitude, Ny) TX = infer_longitudinal_topology(native_longitude, longitude)
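[Editorial aside after PATCH 44.] The convention mapping introduced in PATCHes 43–44 can be illustrated in isolation. The real code relies on Oceananigans' `convert_to_λ₀_λ₀_plus360`; the sketch below substitutes a hypothetical `mod`-based stand-in, `to_λ₀_plus360`, assumed to have the same wrap-into-`[λ₀, λ₀ + 360)` semantics. Only the first bbox endpoint is mapped; the second is recovered from the bbox width, so a bbox straddling the wrap point stays contiguous rather than splitting when both endpoints are mapped independently.

```julia
# Hypothetical stand-in for Oceananigans' convert_to_λ₀_λ₀_plus360:
# shift λ by a multiple of 360° so it lands in [λ₀, λ₀ + 360).
to_λ₀_plus360(λ, λ₀) = λ₀ + mod(λ - λ₀, 360)

# Mirrors PATCH 44: no bbox longitude → nothing to map.
native_convention_longitude(::Nothing, native) = nothing

# Map a bbox longitude into the native longitude convention,
# preserving the bbox width through the wrap.
function native_convention_longitude(bbox_longitude, native)
    λ⁻ = to_λ₀_plus360(bbox_longitude[1], native[1])
    return (λ⁻, λ⁻ + (bbox_longitude[2] - bbox_longitude[1]))
end

# A (-180, 180)-convention bbox mapped onto a (0, 360) native grid:
native_convention_longitude((-20.0, 15.0), (0, 360))  # → (340.0, 375.0)
native_convention_longitude(nothing, (0, 360))        # → nothing
```

The resulting upper endpoint may exceed `native[2]` (375° above); downstream, `restrict` and `infer_longitudinal_topology` receive a contiguous interval, and the centre alignment against file longitudes is handled at read time by `region_info`.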