diff --git a/Project.toml b/Project.toml index ab7f75323..75d099f68 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NumericalEarth" uuid = "904d977b-046a-4731-8b86-9235c0d1ef02" license = "MIT" -version = "0.3.0" +version = "0.3.1" authors = ["NumericalEarth contributors"] [deps] diff --git a/examples/generate_surface_fluxes.jl b/examples/generate_surface_fluxes.jl index 3043bef0f..3c7ac17b0 100644 --- a/examples/generate_surface_fluxes.jl +++ b/examples/generate_surface_fluxes.jl @@ -45,9 +45,9 @@ save("ECCO_continents.png", fig) # - downwelling longwave radiation # # We load in memory only the first two time indices, corresponding to January 1st -# (at 00:00 AM and 03:00 AM), by using `JRA55NetCDFBackend(2)`. +# (at 00:00 AM and 03:00 AM), by using `time_indices_in_memory = 2`. -atmosphere = JRA55PrescribedAtmosphere(; backend = JRA55NetCDFBackend(2)) +atmosphere = JRA55PrescribedAtmosphere(; time_indices_in_memory = 2) ocean = ocean_simulation(grid, closure=nothing) # Now that we have an atmosphere and ocean, we `set!` the ocean temperature and salinity diff --git a/examples/inspect_JRA55_data.jl b/examples/inspect_JRA55_data.jl index 57d4040cc..1c1dc805f 100644 --- a/examples/inspect_JRA55_data.jl +++ b/examples/inspect_JRA55_data.jl @@ -4,8 +4,11 @@ using Oceananigans using Oceananigans.Units using Printf -Qswt = NumericalEarth.JRA55.JRA55FieldTimeSeries(:downwelling_shortwave_radiation) -rht = NumericalEarth.JRA55.JRA55FieldTimeSeries(:relative_humidity) +using NumericalEarth.DataWrangling: Metadata +using NumericalEarth.JRA55: RepeatYearJRA55 + +Qswt = FieldTimeSeries(Metadata(:downwelling_shortwave_radiation; dataset=RepeatYearJRA55()); time_indices_in_memory=8) +rht = FieldTimeSeries(Metadata(:specific_humidity; dataset=RepeatYearJRA55()); time_indices_in_memory=8) function lonlat2xyz(lons::AbstractVector, lats::AbstractVector) x = [cosd(lat) * cosd(lon) for lon in lons, lat in lats] diff --git a/examples/meridional_heat_transport_ecco.jl b/examples/meridional_heat_transport_ecco.jl index 58b3e7849..3ac60f2c2 100755 --- a/examples/meridional_heat_transport_ecco.jl +++ b/examples/meridional_heat_transport_ecco.jl @@ -43,9 +43,9 @@ set!(ocean.model, T=ecco_temperature, S=ecco_salinity) set!(sea_ice.model, h=ecco_sea_ice_thickness, ℵ=ecco_sea_ice_concentration) radiation = Radiation(arch) -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(80), - include_rivers_and_icebergs = false) -esm = OceanSeaIceModel(ocean, sea_ice; atmosphere, radiation) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory = 80) +land = JRA55PrescribedLand(arch; time_indices_in_memory = 80) +esm = OceanSeaIceModel(ocean, sea_ice; atmosphere, land, radiation) simulation = Simulation(esm; Δt=20minutes, stop_time=5*365days) diff --git a/examples/near_global_ocean_simulation.jl b/examples/near_global_ocean_simulation.jl index f57a5e8e0..473707874 100644 --- a/examples/near_global_ocean_simulation.jl +++ b/examples/near_global_ocean_simulation.jl @@ -102,11 +102,10 @@ radiation = Radiation(arch) # The JRA55 dataset provides atmospheric data such as temperature, humidity, and winds # to calculate turbulent fluxes using bulk formulae, see [`InterfaceComputations`](@ref NumericalEarth.EarthSystemModels.InterfaceComputations). # The number of snapshots that are loaded into memory is determined by -# the `backend`. Here, we load 41 snapshots at a time into memory. +# `time_indices_in_memory`. Here, we load 41 snapshots at a time into memory. -jra55_backend = JRA55NetCDFBackend(41) -atmosphere = JRA55PrescribedAtmosphere(arch; backend=jra55_backend) -land = JRA55PrescribedLand(arch; backend=jra55_backend) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory = 41) +land = JRA55PrescribedLand(arch; time_indices_in_memory = 41) # ## The coupled simulation diff --git a/examples/one_degree_simulation.jl b/examples/one_degree_simulation.jl index a08bba029..a84fc91a3 100644 --- a/examples/one_degree_simulation.jl +++ b/examples/one_degree_simulation.jl @@ -96,9 +96,8 @@ set!(sea_ice.model, h=ecco_sea_ice_thickness, ℵ=ecco_sea_ice_concentration) # We force the simulation with a JRA55-do atmospheric reanalysis. radiation = Radiation(arch) -jra55_backend = JRA55NetCDFBackend(80) -atmosphere = JRA55PrescribedAtmosphere(arch; backend=jra55_backend) -land = JRA55PrescribedLand(arch; backend=jra55_backend) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory = 80) +land = JRA55PrescribedLand(arch; time_indices_in_memory = 80) # ### Coupled simulation diff --git a/examples/single_column_os_papa_simulation.jl b/examples/single_column_os_papa_simulation.jl index d0a38baf6..bf869a57b 100644 --- a/examples/single_column_os_papa_simulation.jl +++ b/examples/single_column_os_papa_simulation.jl @@ -63,10 +63,9 @@ set!(ocean.model, T=Metadatum(:temperature, dataset=GLORYSMonthly(), region=col) # We build a `JRA55PrescribedAtmosphere` at the same location as the single-colunm grid # which is based on the JRA55 reanalysis. -atmosphere = JRA55PrescribedAtmosphere(longitude = λ★, - latitude = φ★, +atmosphere = JRA55PrescribedAtmosphere(region = Column(λ★, φ★), end_date = DateTime(1990, 1, 31), # Last day of the simulation - backend = InMemory()) + time_indices_in_memory = 1000) # This builds a representation of the atmosphere on the small grid diff --git a/examples/veros_ocean_forced_simulation.jl b/examples/veros_ocean_forced_simulation.jl index bb95573f5..65db31b33 100644 --- a/examples/veros_ocean_forced_simulation.jl +++ b/examples/veros_ocean_forced_simulation.jl @@ -68,7 +68,7 @@ ocean.set_forcing = set_forcing_tke_only # This includes 2-meter wind velocity, temperature, humidity, downwelling longwave and shortwave # radiation, as well as freshwater fluxes. -atmos = JRA55PrescribedAtmosphere(; backend = JRA55NetCDFBackend(10)) +atmos = JRA55PrescribedAtmosphere(; time_indices_in_memory = 10) # The coupled ocean--atmosphere model. # We use the default radiation model and we do not couple an ice model for simplicity. diff --git a/experiments/arctic_simulation.jl b/experiments/arctic_simulation.jl index afabdf826..a602fe3b8 100644 --- a/experiments/arctic_simulation.jl +++ b/experiments/arctic_simulation.jl @@ -89,7 +89,7 @@ set!(sea_ice.model, h=Metadatum(:sea_ice_thickness; dataset), ##### A Prescribed Atmosphere model ##### -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(40)) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=40) radiation = Radiation() ##### diff --git a/experiments/coupled_simulation/earth_system_coupled_simulation.jl b/experiments/coupled_simulation/earth_system_coupled_simulation.jl index 334ebce91..975602b65 100644 --- a/experiments/coupled_simulation/earth_system_coupled_simulation.jl +++ b/experiments/coupled_simulation/earth_system_coupled_simulation.jl @@ -79,7 +79,7 @@ salinity = ECCOMetadata(:salinity; dir="./") ice_thickness = ECCOMetadata(:sea_ice_thickness; dir="./") ice_concentration = ECCOMetadata(:sea_ice_concentration; dir="./") -atmosphere = JRA55PrescribedAtmosphere(arch, backend=JRA55NetCDFBackend(20)) +atmosphere = JRA55PrescribedAtmosphere(arch, time_indices_in_memory=20) radiation = Radiation(ocean_albedo = LatitudeDependentAlbedo(), sea_ice_albedo=0.6) set!(ocean.model, T=temperature, S=salinity) diff --git a/experiments/flux_climatology/flux_climatology.jl b/experiments/flux_climatology/flux_climatology.jl index 20ad23d60..c041b52e9 100644 --- a/experiments/flux_climatology/flux_climatology.jl +++ b/experiments/flux_climatology/flux_climatology.jl @@ -191,7 +191,7 @@ heat_capacity(ocean::Simulation{<:PrescribedOcean}) = 3995.6 ##### A prescribed atmosphere... ##### -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(1000)) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=1000) ##### ##### A prescribed earth... diff --git a/experiments/one_degree_simulation/debug_tides.jl b/experiments/one_degree_simulation/debug_tides.jl index a06e975bf..8c7c544ba 100644 --- a/experiments/one_degree_simulation/debug_tides.jl +++ b/experiments/one_degree_simulation/debug_tides.jl @@ -8,9 +8,8 @@ using Statistics import SPICE -backend = JRA55NetCDFBackend(41) ρᵒᶜ = 1020 -pa = JRA55_field_time_series(:sea_level_pressure; backend) +pa = JRA55_field_time_series(:sea_level_pressure; time_indices_in_memory=41) grid = pa.grid # dt = 30 * 60 # minutes diff --git a/experiments/one_degree_simulation/generate_tidal_forcing.jl b/experiments/one_degree_simulation/generate_tidal_forcing.jl index 6bd301110..ff76cb277 100644 --- a/experiments/one_degree_simulation/generate_tidal_forcing.jl +++ b/experiments/one_degree_simulation/generate_tidal_forcing.jl @@ -6,8 +6,7 @@ using Dates using GLMakie import SPICE -backend = JRA55NetCDFBackend(41) -pa = JRA55_field_time_series(:sea_level_pressure; backend) +pa = JRA55_field_time_series(:sea_level_pressure; time_indices_in_memory=41) pa_with_tides = FieldTimeSeries{Center, Center, Nothing}(pa.grid, pa.times, path = "sea_level_pressure_plus_tides.jld2", diff --git a/experiments/one_degree_simulation/one_degree_simulation.jl b/experiments/one_degree_simulation/one_degree_simulation.jl index e48bc4d56..f2deda0a2 100644 --- a/experiments/one_degree_simulation/one_degree_simulation.jl +++ b/experiments/one_degree_simulation/one_degree_simulation.jl @@ -46,7 +46,7 @@ set!(ocean.model, T=ECCOMetadata(:temperature; dates=first(dates)), S=ECCOMetadata(:salinity; dates=first(dates))) radiation = Radiation(arch) -atmosphere = JRA55PrescribedAtmosphere(arch; backend=JRA55NetCDFBackend(41)) +atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=41) coupled_model = OceanOnlyModel(ocean; atmosphere, radiation) simulation = Simulation(coupled_model; Δt=10minutes, stop_iteration=100) diff --git a/src/DataWrangling/DataWrangling.jl b/src/DataWrangling/DataWrangling.jl index 3741c6e8a..c250e65b9 100644 --- a/src/DataWrangling/DataWrangling.jl +++ b/src/DataWrangling/DataWrangling.jl @@ -200,7 +200,9 @@ default_mask_value(dataset) = NaN # Fundamentals include("metadata.jl") +include("set_region_data.jl") include("metadata_field.jl") +include("dataset_backend.jl") include("metadata_field_time_series.jl") include("inpainting.jl") include("restoring.jl") diff --git a/src/DataWrangling/ECCO/ECCO.jl b/src/DataWrangling/ECCO/ECCO.jl index 576e81014..248f3b808 100644 --- a/src/DataWrangling/ECCO/ECCO.jl +++ b/src/DataWrangling/ECCO/ECCO.jl @@ -33,8 +33,7 @@ using NumericalEarth.DataWrangling: location, compute_mask, inpaint_mask!, - set_metadata_field!, - extract_column! + set_metadata_field! using KernelAbstractions: @kernel, @index @@ -51,6 +50,8 @@ import NumericalEarth.DataWrangling: metaprefix, longitude_interfaces, latitude_interfaces, + longitude_name, + latitude_name, z_interfaces, is_three_dimensional, inpainted_metadata_path, @@ -113,6 +114,13 @@ longitude_interfaces(::ECCODataset) = (0, 360) longitude_interfaces(::ECCO4Monthly) = (-180, 180) latitude_interfaces(::ECCODataset) = (-90, 90) +longitude_name(::Metadata{<:ECCODataset}) = "LONGITUDE_T" +latitude_name(::Metadata{<:ECCODataset}) = "LATITUDE_T" +longitude_name(::Metadata{<:ECCO4Monthly}) = "longitude" +latitude_name(::Metadata{<:ECCO4Monthly}) = "latitude" +longitude_name(::Metadata{<:ECCO4DarwinMonthly}) = "longitude" +latitude_name(::Metadata{<:ECCO4DarwinMonthly}) = "latitude" + z_interfaces(::ECCODataset) = [ -6128.75, -5683.75, @@ -371,37 +379,4 @@ inpainted_metadata_path(metadata::ECCOMetadatum) = joinpath(metadata.dir, inpain include("ECCO_atmosphere.jl") -##### -##### Column Field for ECCO datasets (which always download globally) -##### - -using Oceananigans.BoundaryConditions: fill_halo_regions! - -const ECCOColumnMetadatum = Metadatum{<:ECCODataset, <:Any, <:Column} - -function Oceananigans.Fields.Field(metadata::ECCOColumnMetadatum, arch=CPU(); - inpainting = default_inpainting(metadata), - mask = nothing, - halo = (3, 3, 3), - cache_inpainted_data = true) - - download_dataset(metadata) - column_grid = native_grid(metadata, arch; halo) - - # Build a full-grid Field without a region to load the global data - global_metadatum = Metadatum(metadata.name; - dataset = metadata.dataset, - date = metadata.dates) - - intermediate_field = Field(global_metadatum, arch; inpainting, mask, halo, cache_inpainted_data) - fill_halo_regions!(intermediate_field) - - # Extract the column - _, _, LZ = location(metadata) - column_field = Field{Nothing, Nothing, LZ}(column_grid) - extract_column!(column_field, intermediate_field, metadata.region) - - return column_field -end - end # Module diff --git a/src/DataWrangling/ECCO/ECCO_atmosphere.jl b/src/DataWrangling/ECCO/ECCO_atmosphere.jl index 3481c4e03..50adf50e7 100644 --- a/src/DataWrangling/ECCO/ECCO_atmosphere.jl +++ b/src/DataWrangling/ECCO/ECCO_atmosphere.jl @@ -31,26 +31,19 @@ function ECCOPrescribedAtmosphere(architecture = CPU(), FT = Float32; surface_layer_height = 2, # meters other_kw...) - ua_meta = Metadata(:eastward_wind; dataset, start_date, end_date, dir) - va_meta = Metadata(:northward_wind; dataset, start_date, end_date, dir) - Ta_meta = Metadata(:air_temperature; dataset, start_date, end_date, dir) - qa_meta = Metadata(:air_specific_humidity; dataset, start_date, end_date, dir) - pa_meta = Metadata(:sea_level_pressure; dataset, start_date, end_date, dir) - ℐꜜˡʷ_meta = Metadata(:downwelling_longwave; dataset, start_date, end_date, dir) - ℐꜜˢʷ_meta = Metadata(:downwelling_shortwave; dataset, start_date, end_date, dir) - Fr_meta = Metadata(:rain_freshwater_flux; dataset, start_date, end_date, dir) - - kw = (; time_indices_in_memory, time_indexing) + kw = (; time_indexing, time_indices_in_memory) kw = merge(kw, other_kw) - ua = FieldTimeSeries(ua_meta, architecture; kw...) - va = FieldTimeSeries(va_meta, architecture; kw...) - Ta = FieldTimeSeries(Ta_meta, architecture; kw...) - qa = FieldTimeSeries(qa_meta, architecture; kw...) - pa = FieldTimeSeries(pa_meta, architecture; kw...) - ℐꜜˡʷ = FieldTimeSeries(ℐꜜˡʷ_meta, architecture; kw...) - ℐꜜˢʷ = FieldTimeSeries(ℐꜜˢʷ_meta, architecture; kw...) - Fr = FieldTimeSeries(Fr_meta, architecture; kw...) + ECCOFieldTimeSeries(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir), architecture; kw...) + + ua = ECCOFieldTimeSeries(:eastward_wind) + va = ECCOFieldTimeSeries(:northward_wind) + Ta = ECCOFieldTimeSeries(:air_temperature) + qa = ECCOFieldTimeSeries(:air_specific_humidity) + pa = ECCOFieldTimeSeries(:sea_level_pressure) + ℐꜜˡʷ = ECCOFieldTimeSeries(:downwelling_longwave) + ℐꜜˢʷ = ECCOFieldTimeSeries(:downwelling_shortwave) + Fr = ECCOFieldTimeSeries(:rain_freshwater_flux) freshwater_flux = (; rain = Fr) diff --git a/src/DataWrangling/EN4/EN4.jl b/src/DataWrangling/EN4/EN4.jl index 851401439..8e98b1adf 100644 --- a/src/DataWrangling/EN4/EN4.jl +++ b/src/DataWrangling/EN4/EN4.jl @@ -45,6 +45,8 @@ import NumericalEarth.DataWrangling: z_interfaces, longitude_interfaces, latitude_interfaces, + longitude_name, + latitude_name, is_three_dimensional, reversed_vertical_axis, inpainted_metadata_path, @@ -70,6 +72,8 @@ reversed_vertical_axis(::EN4Monthly) = true longitude_interfaces(::EN4Monthly) = (0.5, 360.5) latitude_interfaces(::EN4Monthly) = (-83.5, 89.5) +longitude_name(::Metadata{<:EN4Monthly}) = "lon" +latitude_name(::Metadata{<:EN4Monthly}) = "lat" available_variables(::EN4Monthly) = EN4_dataset_variable_names z_interfaces(::EN4Monthly) = [ diff --git a/src/DataWrangling/JRA55/JRA55.jl b/src/DataWrangling/JRA55/JRA55.jl index 5751b9dcb..92eb8128c 100644 --- a/src/DataWrangling/JRA55/JRA55.jl +++ b/src/DataWrangling/JRA55/JRA55.jl @@ -1,6 +1,6 @@ module JRA55 -export JRA55FieldTimeSeries, JRA55PrescribedAtmosphere, JRA55PrescribedLand, RepeatYearJRA55, MultiYearJRA55 +export JRA55PrescribedAtmosphere, JRA55PrescribedLand, RepeatYearJRA55, MultiYearJRA55, JRA55FieldTimeSeries using Oceananigans using Oceananigans.Units diff --git a/src/DataWrangling/JRA55/JRA55_field_time_series.jl b/src/DataWrangling/JRA55/JRA55_field_time_series.jl index ea3318175..dfd90e635 100644 --- a/src/DataWrangling/JRA55/JRA55_field_time_series.jl +++ b/src/DataWrangling/JRA55/JRA55_field_time_series.jl @@ -1,467 +1,123 @@ using NumericalEarth.DataWrangling: all_dates, native_times using NumericalEarth.DataWrangling: compute_native_date_range +using NumericalEarth.DataWrangling: set_region_data! using Oceananigans.Grids: AbstractGrid using Oceananigans.OutputReaders: PartlyInMemory using Adapt -compute_bounding_nodes(::Nothing, ::Nothing, LH, hnodes) = nothing -compute_bounding_nodes(bounds, ::Nothing, LH, hnodes) = bounds +import NumericalEarth.DataWrangling: retrieve_data -function compute_bounding_nodes(x::Number, ::Nothing, LH, hnodes) - ϵ = convert(typeof(x), 0.001) # arbitrary? - return (x - ϵ, x + ϵ) -end - -# TODO: remove the allowscalar -function compute_bounding_nodes(::Nothing, grid, LH, hnodes) - hg = hnodes(grid, LH()) - h₁ = @allowscalar minimum(hg) - h₂ = @allowscalar maximum(hg) - return h₁, h₂ -end -function compute_bounding_indices(::Nothing, hc) - Nh = length(hc) - return 1, Nh -end +const JRA55NetCDFFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:JRA55Metadata}} +const JRA55NetCDFFTSRepeatYear = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:Metadata{<:RepeatYearJRA55}}} +const JRA55NetCDFFTSMultipleYears = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{<:Any, <:Any, <:Any, <:Metadata{<:MultiYearJRA55}}} -function compute_bounding_indices(bounds::Tuple, hc) - h₁, h₂ = bounds - Nh = length(hc) +""" + retrieve_data(metadatum::JRA55Metadatum) - # The following should work. If ᵒ are the extrema of nodes we want to - # interpolate to, and the following is a sketch of the JRA55 native grid, - # - # 1 2 3 4 5 - # | | | | | | - # | x ᵒ | x | x | x ᵒ | x | - # | | | | | | - # 1 2 3 4 5 6 - # - # then for example, we should find that (iᵢ, i₂) = (1, 5). - # So we want to reduce the first index by one, and limit them - # both by the available data. There could be some mismatch due - # to the use of different coordinate systems (ie whether λ ∈ (0, 360) - # which we may also need to handle separately. - i₁ = searchsortedfirst(hc, h₁) - i₂ = searchsortedfirst(hc, h₂) - i₁ = max(1, i₁ - 1) - i₂ = min(Nh, i₂) +Read the 2D slice from the JRA55 NetCDF file corresponding to `metadatum`'s single date. +`RepeatYearJRA55` resolves the index from the position within `all_dates` (the file holds +exactly 2920 entries for 1990 and aligns 1:1 with `all_dates`); `MultiYearJRA55` resolves +the index against the file's own `time` axis (one Gregorian-calendar file per year). +""" +function retrieve_data(metadatum::RepeatYearJRA55Metadatum) + path = metadata_path(metadatum) + name = dataset_variable_name(metadatum) - return i₁, i₂ -end + dates = all_dates(metadatum.dataset, metadatum.name) + file_idx = findfirst(==(metadatum.dates), dates) -infer_longitudinal_topology(::Nothing) = Periodic + if isnothing(file_idx) + throw(ArgumentError("Date $(metadatum.dates) not found in $(metadatum.dataset) :$(metadatum.name) all_dates.")) + end -function infer_longitudinal_topology(λbounds) - λ₁, λ₂ = λbounds - TX = λ₂ - λ₁ ≈ 360 ? Periodic : Bounded - return TX + ds = Dataset(path) + data = ds[name][:, :, file_idx] + close(ds) + return data end -function compute_bounding_indices(longitude, latitude, grid, LX, LY, λc, φc) - λbounds = compute_bounding_nodes(longitude, grid, LX, λnodes) - φbounds = compute_bounding_nodes(latitude, grid, LY, φnodes) +function retrieve_data(metadatum::MultiYearJRA55Metadatum) + path = metadata_path(metadatum) + name = dataset_variable_name(metadatum) - i₁, i₂ = compute_bounding_indices(λbounds, λc) - j₁, j₂ = compute_bounding_indices(φbounds, φc) - TX = infer_longitudinal_topology(λbounds) + ds = Dataset(path) + file_dates = ds["time"][:] + file_idx = findfirst(==(metadatum.dates), file_dates) - return i₁, i₂, j₁, j₂, TX -end + if isnothing(file_idx) + close(ds) + throw(ArgumentError(string("Date ", metadatum.dates, " not found in JRA55 multi-year file ", path, "."))) + end -struct JRA55NetCDFBackend{M} <: AbstractInMemoryBackend{Int} - start :: Int - length :: Int - metadata :: M + data = ds[name][:, :, file_idx] + close(ds) + return data end -Adapt.adapt_structure(to, b::JRA55NetCDFBackend) = JRA55NetCDFBackend(b.start, b.length, nothing) - -""" - JRA55NetCDFBackend(length) - -Represents a JRA55 FieldTimeSeries backed by JRA55 native netCDF files. -""" -JRA55NetCDFBackend(length, metadata::Metadata) = JRA55NetCDFBackend(1, length, metadata) -JRA55NetCDFBackend(start::Integer, length::Integer) = JRA55NetCDFBackend(start, length, nothing) - -# Metadata - agnostic constructor -JRA55NetCDFBackend(length) = JRA55NetCDFBackend(1, length, nothing) - -Base.length(backend::JRA55NetCDFBackend) = backend.length -Base.summary(backend::JRA55NetCDFBackend) = string("JRA55NetCDFBackend(", backend.start, ", ", backend.length, ")") - -const JRA55NetCDFFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:JRA55NetCDFBackend} -const JRA55NetCDFFTSRepeatYear = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:JRA55NetCDFBackend{<:Metadata{<:RepeatYearJRA55}}} -const JRA55NetCDFFTSMultipleYears = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:JRA55NetCDFBackend{<:Metadata{<:MultiYearJRA55}}} - -# Note that each file should have the variables -# - ds["time"]: time coordinate -# - ds["lon"]: longitude at the location of the variable -# - ds["lat"]: latitude at the location of the variable -# - ds["lon_bnds"]: bounding longitudes between which variables are averaged -# - ds["lat_bnds"]: bounding latitudes between which variables are averaged -# - ds[shortname]: the variable data +# Split at the wrap point if `nn` cycles past 1 — DiskArrays requires sorted indices. +function jra55_read_data(ds, name, i, j, nn) + if issorted(nn) + return ds[name][i, j, nn] + else + m = findfirst(==(1), nn) + d1 = ds[name][i, j, nn[1:m-1]] + d2 = ds[name][i, j, nn[m:end]] + return cat(d1, d2; dims=3) + end +end -# Simple case, only one file per variable, no need to deal with multiple files function set!(fts::JRA55NetCDFFTSRepeatYear, backend=fts.backend) - metadata = backend.metadata - - filename = metadata.filename - path = joinpath(metadata.dir, filename) - ds = Dataset(path) - - # Nodes at the variable location + ds = Dataset(joinpath(metadata.dir, metadata.filename)) λc = ds["lon"][:] φc = ds["lat"][:] - LX, LY, LZ = location(fts) - i₁, i₂, j₁, j₂, TX = compute_bounding_indices(nothing, nothing, fts.grid, LX, LY, λc, φc) - - nn = time_indices(fts) - nn = collect(nn) - name = dataset_variable_name(fts.backend.metadata) - - if issorted(nn) - data = ds[name][i₁:i₂, j₁:j₂, nn] - else - # The time indices may be cycling past 1; eg ti = [6, 7, 8, 1]. - # However, DiskArrays does not seem to support loading data with unsorted - # indices. So to handle this, we load the data in chunks, where each chunk's - # indices are sorted, and then glue the data together. - m = findfirst(n -> n == 1, nn) - n1 = nn[1:m-1] - n2 = nn[m:end] - - data1 = ds[name][i₁:i₂, j₁:j₂, n1] - data2 = ds[name][i₁:i₂, j₁:j₂, n2] - data = cat(data1, data2, dims=3) - end + nn = collect(time_indices(fts)) + name = dataset_variable_name(metadata) + raw = jra55_read_data(ds, name, :, :, nn) close(ds) + full_data = reshape(raw, length(λc), length(φc), 1, length(nn)) - copyto!(interior(fts, :, :, 1, :), data) + set_region_data!(fts, full_data, λc, φc, metadata) fill_halo_regions!(fts) - return nothing end -# Tricky case: multiple files per variable -- one file per year -- -# we need to infer the file name from the metadata and split the data loading function set!(fts::JRA55NetCDFFTSMultipleYears, backend=fts.backend) - metadata = backend.metadata + name = dataset_variable_name(metadata) - filename = metadata.filename - filename = unique(filename) - name = dataset_variable_name(metadata) - start_date = first_date(metadata.dataset, metadata.name) + ftsn = collect(time_indices(fts)) + slot_dates = metadata.dates[ftsn] + needed_files = unique(getfilename(metadata.filename, n) for n in ftsn) - for file in filename - - path = joinpath(metadata.dir, file) - ds = Dataset(path) - - # This can be simplified once we start supporting a - # datetime `Clock` in Oceananigans + for file in needed_files + ds = Dataset(joinpath(metadata.dir, file)) file_dates = ds["time"][:] - file_indices = 1:length(file_dates) - file_times = zeros(length(file_dates)) - for (t, date) in enumerate(file_dates) - delta = date - start_date - delta = Second(delta).value - file_times[t] = delta - end - - ftsn = time_indices(fts) - ftsn = collect(ftsn) - # Intersect the time indices with the file times - nn = findall(n -> file_times[n] ∈ fts.times[ftsn], file_indices) - ftsn = findall(n -> fts.times[n] ∈ file_times[nn], ftsn) + nn = Int[] + ftsn_loc = Int[] + for (loc, slot_date) in enumerate(slot_dates) + file_idx = findfirst(==(slot_date), file_dates) + if !isnothing(file_idx) + push!(nn, file_idx) + push!(ftsn_loc, loc) + end + end if !isempty(nn) - # Nodes at the variable location λc = ds["lon"][:] φc = ds["lat"][:] - LX, LY, LZ = location(fts) - i₁, i₂, j₁, j₂, TX = compute_bounding_indices(nothing, nothing, fts.grid, LX, LY, λc, φc) - - - if issorted(nn) - data = ds[name][i₁:i₂, j₁:j₂, nn] - else - # The time indices may be cycling past 1; eg ti = [6, 7, 8, 1]. - # However, DiskArrays does not seem to support loading data with unsorted - # indices. So to handle this, we load the data in chunks, where each chunk's - # indices are sorted, and then glue the data together. - m = findfirst(n -> n == 1, nn) - n1 = nn[1:m-1] - n2 = nn[m:end] - - data1 = ds[name][i₁:i₂, j₁:j₂, n1] - data2 = ds[name][i₁:i₂, j₁:j₂, n2] - data = cat(data1, data2, dims=3) - end - - close(ds) - - # We need to set the time index for each file - # Find start index corresponding to the underlying data - for n in 1:length(nn) - copyto!(interior(fts, :, :, 1, ftsn[n]), data[:, :, n]) - end + raw = jra55_read_data(ds, name, :, :, nn) + full_data = reshape(raw, length(λc), length(φc), 1, length(nn)) + set_region_data!(fts, full_data, λc, φc, metadata; slot_indices = ftsn_loc) end + close(ds) end fill_halo_regions!(fts) - return nothing end -new_backend(b::JRA55NetCDFBackend, start, length) = JRA55NetCDFBackend(start, length, b.metadata) - -""" - JRA55FieldTimeSeries(variable_name, architecture=CPU(), FT=Float32; - dataset = RepeatYearJRA55(), - dates = all_JRA55_dates(version), - latitude = nothing, - longitude = nothing, - dir = download_JRA55_cache, - backend = InMemory(), - time_indexing = Cyclical()) - -Return a `FieldTimeSeries` containing atmospheric reanalysis data for `variable_name`, -which describes one of the variables from the Japanese 55-year atmospheric reanalysis -for driving ocean-sea ice models (JRA55-do). The JRA55-do dataset is described by [tsujino2018jra](@citet). - -The `variable_name`s (and their `shortname`s used in the netCDF files) available from the JRA55-do are: -- `:river_freshwater_flux` ("friver") -- `:rain_freshwater_flux` ("prra") -- `:snow_freshwater_flux` ("prsn") -- `:iceberg_freshwater_flux` ("licalvf") -- `:specific_humidity` ("huss") -- `:sea_level_pressure` ("psl") -- `:relative_humidity` ("rhuss") -- `:downwelling_longwave_radiation` ("rlds") -- `:downwelling_shortwave_radiation` ("rsds") -- `:temperature` ("ras") -- `:eastward_velocity` ("uas") -- `:northward_velocity` ("vas") - -Keyword arguments -================= - -- `architecture`: Architecture for the `FieldTimeSeries`. Default: CPU() - -- `dataset`: The data dataset; supported datasets are: `RepeatYearJRA55()` and `MultiYearJRA55()`. - `MultiYearJRA55()` refers to the full length of the JRA55-do dataset; `RepeatYearJRA55()` - refers to the "repeat-year forcing" dataset derived from JRA55-do. Default: `RepeatYearJRA55()`. - - !!! info "Repeat-year forcing" - - For more information about the derivation of the repeat-year forcing dataset, see [stewart2020jra55](@citet). - - The repeat year in `RepeatYearJRA55()` corresponds to May 1st, 1990 - April 30th, 1991. However, the - returned dataset has dates that range from January 1st to December 31st. This implies - that the first 4 months of the `JRA55RepeatYear()` dataset correspond to year 1991 from the JRA55 - reanalysis and the rest 8 months from 1990. - -- `start_date`: The starting date to use for the dataset. Default: `first_date(dataset, variable_name)`. - -- `end_date`: The ending date to use for the dataset. Default: `end_date(dataset, variable_name)`. - -- `dir`: The directory of the data file. Default: `NumericalEarth.JRA55.download_JRA55_cache`. - -- `time_indexing`: The time indexing scheme for the field time series. Default: `Cyclical()`. - -- `latitude`: Guiding latitude bounds for the resulting grid. - Used to slice the data when loading into memory. - Default: nothing, which retains the latitude range of the native grid. - -- `longitude`: Guiding longitude bounds for the resulting grid. - Used to slice the data when loading into memory. - Default: nothing, which retains the longitude range of the native grid. - -- `backend`: Backend for the `FieldTimeSeries`. The two options are: - * `InMemory()`: the whole time series is loaded into memory. - * `JRA55NetCDFBackend(total_time_instances_in_memory)`: only a subset of the time series - is loaded into memory. Default: `InMemory()`. - -References -========== - -- Tsujino et al. (2018). JRA-55 based surface dataset for driving ocean-sea-ice models (JRA55-do), _Ocean Modelling_, **130(1)**, 79-139. - -- Stewart et al. (2020). JRA55-do-based repeat year forcing datasets for driving ocean–sea-ice models, _Ocean Modelling_, **147**, 101557. -""" -function JRA55FieldTimeSeries(variable_name::Symbol, architecture=CPU(), FT=Float32; - dataset = RepeatYearJRA55(), - start_date = first_date(dataset, variable_name), - end_date = last_date(dataset, variable_name), - dir = download_JRA55_cache, - kw...) - - native_dates = all_dates(dataset, variable_name) - dates = compute_native_date_range(native_dates, start_date, end_date) - metadata = Metadata(variable_name; dataset, dates, dir) - - return JRA55FieldTimeSeries(metadata, architecture, FT; kw...) -end - -function JRA55FieldTimeSeries(metadata::JRA55Metadata, architecture=CPU(), FT=Float32; - latitude = nothing, - longitude = nothing, - backend = InMemory(), - time_indexing = Cyclical()) - - # Cannot use `TotallyInMemory` backend with MultiYearJRA55 dataset - if metadata.dataset isa MultiYearJRA55 && backend isa TotallyInMemory - msg = string("The `InMemory` backend is not supported for the MultiYearJRA55 dataset.") - throw(ArgumentError(msg)) - end - - # First thing: we download the dataset! - download_dataset(metadata) - - # Regularize the backend in case of `JRA55NetCDFBackend` - if backend isa JRA55NetCDFBackend - if backend.metadata isa Nothing - backend = JRA55NetCDFBackend(backend.length, metadata) - end - - if backend.length > length(metadata) - backend = JRA55NetCDFBackend(backend.start, length(metadata), metadata) - end - end - - # Unpack metadata details - dataset = metadata.dataset - name = metadata.name - time_indices = JRA55_time_indices(dataset, metadata.dates, name) - - # Change the metadata to reflect the actual time indices - dates = all_dates(dataset, name)[time_indices] - metadata = Metadata(metadata.name; dataset=metadata.dataset, dates, dir=metadata.dir) - - shortname = dataset_variable_name(metadata) - variable_name = metadata.name - - filepath = metadata_path(metadata) # Might be multiple paths!!! - filepath = filepath isa AbstractArray ? first(filepath) : filepath - - # OnDisk backends do not support time interpolation! - # Disallow OnDisk for JRA55 dataset loading - if ((backend isa InMemory) && !isnothing(backend.length)) || backend isa OnDisk - msg = string("We cannot load the JRA55 dataset with a $(backend) backend. Use `InMemory()` or `JRA55NetCDFBackend(N)` instead.") - throw(ArgumentError(msg)) - end - - if !(variable_name ∈ JRA55_variable_names) - variable_strs = Tuple(" - :$name \n" for name in JRA55_variable_names) - variables_msg = prod(variable_strs) - - msg = string("The variable :$variable_name is not provided by the JRA55-do dataset!", '\n', - "The variables provided by the JRA55-do dataset are:", '\n', - variables_msg) - - throw(ArgumentError(msg)) - end - - # Record some important user decisions - totally_in_memory = backend isa TotallyInMemory - - # Determine default time indices - if totally_in_memory - # In this case, the whole time series is in memory. - # Either the time series is short, or we are doing a limited-area - # simulation, like in a single column. So, we conservatively - # set a default `time_indices = 1:2`. - time_indices_in_memory = time_indices - native_fts_architecture = architecture - else - # In this case, part or all of the time series will be stored in a file. - # Note: if the user has provided a grid, we will have to preprocess the - # .nc JRA55 data into a .jld2 file. In this case, `time_indices` refers - # to the time_indices that we will preprocess; - # by default we choose all of them. The architecture is only the - # architecture used for preprocessing, which typically will be CPU() - # even if we would like the final FieldTimeSeries on the GPU. - time_indices_in_memory = 1:length(backend) - native_fts_architecture = architecture - end - - ds = Dataset(filepath) - - # Note that each file should have the variables - # - ds["time"]: time coordinate - # - ds["lon"]: longitude at the location of the variable - # - ds["lat"]: latitude at the location of the variable - # - ds["lon_bnds"]: bounding longitudes between which variables are averaged - # - ds["lat_bnds"]: bounding latitudes between which variables are averaged - # - ds[shortname]: the variable data - - # Nodes at the variable location - λc = ds["lon"][:] - φc = ds["lat"][:] - - # Interfaces for the "native" JRA55 grid - λn = Array(ds["lon_bnds"][1, :]) - φn = Array(ds["lat_bnds"][1, :]) - - # The netCDF coordinates lon_bnds and lat_bnds do not include - # the last interfaces, so we push them here. - push!(φn, 90) - push!(λn, λn[1] + 360) - - i₁, i₂, j₁, j₂, TX = compute_bounding_indices(longitude, latitude, nothing, Center, Center, λc, φc) - - λr = λn[i₁:i₂+1] - φr = φn[j₁:j₂+1] - Nrx = length(λr) - 1 - Nry = length(φr) - 1 - close(ds) - - N = (Nrx, Nry) - H = min.(N, (3, 3)) - - JRA55_native_grid = LatitudeLongitudeGrid(native_fts_architecture, FT; - halo = H, - size = N, - longitude = λr, - latitude = φr, - topology = (TX, Bounded, Flat)) - - boundary_conditions = FieldBoundaryConditions(JRA55_native_grid, (Center(), Center(), nothing)) - start_time = first_date(metadata.dataset, metadata.name) - times = native_times(metadata; start_time) - - if backend isa JRA55NetCDFBackend - fts = FieldTimeSeries{Center, Center, Nothing}(JRA55_native_grid, times; - backend, - time_indexing, - boundary_conditions, - path = filepath, - name = shortname) - - set!(fts) - return fts - else - fts = FieldTimeSeries{Center, Center, Nothing}(JRA55_native_grid, times; - time_indexing, - backend, - boundary_conditions) - - # Fill the data in a GPU-friendly manner - ds = Dataset(filepath) - data = ds[shortname][i₁:i₂, j₁:j₂, time_indices_in_memory] - close(ds) - - copyto!(interior(fts, :, :, 1, :), data) - fill_halo_regions!(fts) - - return fts - end -end diff --git a/src/DataWrangling/JRA55/JRA55_metadata.jl b/src/DataWrangling/JRA55/JRA55_metadata.jl index 53f1fa58c..29bbdab8e 100644 --- a/src/DataWrangling/JRA55/JRA55_metadata.jl +++ b/src/DataWrangling/JRA55/JRA55_metadata.jl @@ -5,27 +5,73 @@ using Downloads using Oceananigans.DistributedComputations using NumericalEarth.DataWrangling -using NumericalEarth.DataWrangling: Metadata, metadata_path, download_progress, AnyDateTime +using NumericalEarth.DataWrangling: Metadata, metadata_path, download_progress, AnyDateTime, DatasetBackend import Dates: year, month, day import Oceananigans.Fields: set! import Base -import NumericalEarth.DataWrangling: all_dates, metadata_filename, build_filename, download_dataset, default_download_directory, available_variables -struct MultiYearJRA55 end -struct RepeatYearJRA55 end +import Oceananigans.Fields: set!, location +import NumericalEarth.DataWrangling: all_dates, + metadata_filename, + build_filename, + download_dataset, + default_download_directory, + dataset_variable_name, + available_variables, + default_inpainting, + getfilename, + longitude_interfaces, + latitude_interfaces, + longitude_name, + latitude_name, + is_three_dimensional + +abstract type JRA55Dataset end + +struct MultiYearJRA55 <: JRA55Dataset end +struct RepeatYearJRA55 <: JRA55Dataset end + +const JRA55Metadata{D} = Metadata{<:JRA55Dataset, D} +const JRA55Metadatum = Metadatum{<:JRA55Dataset} + +const RepeatYearJRA55Metadatum = Metadatum{<:RepeatYearJRA55} +const MultiYearJRA55Metadatum = Metadatum{<:MultiYearJRA55} + +default_download_directory(::JRA55Dataset) = download_JRA55_cache + +function Base.size(::JRA55Dataset, variable) + if variable ∈ [:river_freshwater_flux, :iceberg_freshwater_flux] + (1440, 720, 1) + else + (640, 320, 1) + end +end + +longitude_interfaces(md::JRA55Metadata) = first(jra55_native_interfaces(metadata_path(first(md)))) +latitude_interfaces(md::JRA55Metadata) = last(jra55_native_interfaces(metadata_path(first(md)))) -const JRA55Metadata{D} = Metadata{<:Union{<:MultiYearJRA55, <:RepeatYearJRA55}, D} -const JRA55Metadatum = Metadatum{<:Union{<:MultiYearJRA55, <:RepeatYearJRA55}} +longitude_name(::JRA55Metadata) = "lon" +latitude_name(::JRA55Metadata) = "lat" -default_download_directory(::Union{<:MultiYearJRA55, <:RepeatYearJRA55}) = download_JRA55_cache +# `lon_bnds`/`lat_bnds` hold only the left/lower interface; append the trailing one. +function jra55_native_interfaces(path) + ds = Dataset(path) + λn = Array{Float64}(ds["lon_bnds"][1, :]) + φn = Array{Float64}(ds["lat_bnds"][1, :]) + close(ds) -Base.size(data::JRA55Metadata) = (640, 320, length(data.dates)) -Base.size(::JRA55Metadatum) = (640, 320, 1) + # `lon_bnds` / `lat_bnds` hold only the left/lower interface, + # so we need to append the trailing interface. + push!(λn, λn[1] + 360) + push!(φn, 90) + return λn, φn +end -# JRA55 is a spatially 2D dataset is_three_dimensional(data::JRA55Metadata) = false +default_inpainting(::JRA55Metadata) = nothing + # The whole range of dates in the different dataset datasets # NOTE! rivers and icebergs have a different frequency! (typical JRA55 data is three-hourly while rivers and icebergs are daily) function all_dates(::RepeatYearJRA55, name) @@ -39,7 +85,7 @@ end all_dates(::MultiYearJRA55, name) = JRA55_multiple_year_dates[name] # Fallback, if we not provide the name, take the highest frequency -all_dates(dataset::Union{<:MultiYearJRA55, <:RepeatYearJRA55}) = all_dates(dataset, :temperature) +all_dates(dataset::JRA55Dataset) = all_dates(dataset, :temperature) # Valid for all JRA55 datasets function JRA55_time_indices(dataset, dates, name) @@ -57,11 +103,8 @@ end # File name generation specific to each Dataset dataset # Note that `RepeatYearJRA55` has only one file associated, so the filename # is independent of the date. Override the multi-date fallback to return a plain String. -metadata_filename(::RepeatYearJRA55, name, date, region) = - "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" - -build_filename(::RepeatYearJRA55, name, dates::AbstractArray, region) = - "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" +metadata_filename(::RepeatYearJRA55, name, date, region) = "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" +build_filename(::RepeatYearJRA55, name, dates::AbstractArray, region) = "RYF." * JRA55_dataset_variable_names[name] * ".1990_1991.nc" function metadata_filename(::MultiYearJRA55, name, date, region) shortname = JRA55_dataset_variable_names[name] @@ -84,8 +127,9 @@ end # Convenience functions dataset_variable_name(data::JRA55Metadata) = JRA55_dataset_variable_names[data.name] -available_variables(::MultiYearJRA55) = JRA55_variable_names -available_variables(::RepeatYearJRA55) = JRA55_variable_names +location(::JRA55Metadata) = (Center, Center, Center) + +available_variables(::JRA55Dataset) = JRA55_variable_names # A list of all variables provided in the JRA55 dataset: JRA55_variable_names = (:river_freshwater_flux, diff --git a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl index ae395238d..165302e10 100644 --- a/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl +++ b/src/DataWrangling/JRA55/JRA55_prescribed_atmosphere.jl @@ -1,44 +1,50 @@ const AA = Oceananigans.Architectures.AbstractArchitecture -JRA55PrescribedAtmosphere(arch::Distributed, FT = Float32; kw...) = +JRA55PrescribedAtmosphere(arch::Distributed; kw...) = JRA55PrescribedAtmosphere(child_architecture(arch); kw...) - """ - JRA55PrescribedAtmosphere([architecture = CPU(), FT = Float32]; + JRA55PrescribedAtmosphere([architecture = CPU()]; dataset = RepeatYearJRA55(), start_date = first_date(dataset, :temperature), end_date = last_date(dataset, :temperature), - backend = JRA55NetCDFBackend(10), + dir = download_JRA55_cache, + time_indices_in_memory = 10, time_indexing = Cyclical(), surface_layer_height = 10, # meters + region = nothing, other_kw...) -Return a [`PrescribedAtmosphere`](@ref) representing JRA55 reanalysis data. -The atmospheric data will be held in `JRA55FieldTimeSeries` objects containing. -For a detailed description of the keyword arguments, see the [`JRA55FieldTimeSeries`](@ref) constructor. +Return a [`PrescribedAtmosphere`](@ref) representing JRA55 reanalysis data. Each atmospheric field is constructed via +`FieldTimeSeries(::JRA55Metadata)`, which uses a `DatasetBackend` parameterised by JRA55 metadata so that the JRA55-specific +`set!` (chunked-yearly NetCDF) is dispatched. +The `region` keyword restricts the atmosphere to a sub-domain of the global JRA55 grid. """ -function JRA55PrescribedAtmosphere(architecture = CPU(), FT = Float32; +function JRA55PrescribedAtmosphere(architecture = CPU(); dataset = RepeatYearJRA55(), start_date = first_date(dataset, :temperature), end_date = last_date(dataset, :temperature), - backend = JRA55NetCDFBackend(10), + dir = download_JRA55_cache, + time_indices_in_memory = 10, time_indexing = Cyclical(), surface_layer_height = 10, # meters + region = nothing, other_kw...) - kw = (; time_indexing, backend, start_date, end_date, dataset) + kw = (; time_indexing, time_indices_in_memory) kw = merge(kw, other_kw) - ua = JRA55FieldTimeSeries(:eastward_velocity, architecture, FT; kw...) - va = JRA55FieldTimeSeries(:northward_velocity, architecture, FT; kw...) - Ta = JRA55FieldTimeSeries(:temperature, architecture, FT; kw...) - qa = JRA55FieldTimeSeries(:specific_humidity, architecture, FT; kw...) - pa = JRA55FieldTimeSeries(:sea_level_pressure, architecture, FT; kw...) - Fra = JRA55FieldTimeSeries(:rain_freshwater_flux, architecture, FT; kw...) - Fsn = JRA55FieldTimeSeries(:snow_freshwater_flux, architecture, FT; kw...) - ℐꜜˡʷ = JRA55FieldTimeSeries(:downwelling_longwave_radiation, architecture, FT; kw...) - ℐꜜˢʷ = JRA55FieldTimeSeries(:downwelling_shortwave_radiation, architecture, FT; kw...) + JRA55FieldTimeSeries(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir, region), architecture; kw...) + + ua = JRA55FieldTimeSeries(:eastward_velocity) + va = JRA55FieldTimeSeries(:northward_velocity) + Ta = JRA55FieldTimeSeries(:temperature) + qa = JRA55FieldTimeSeries(:specific_humidity) + pa = JRA55FieldTimeSeries(:sea_level_pressure) + Fra = JRA55FieldTimeSeries(:rain_freshwater_flux) + Fsn = JRA55FieldTimeSeries(:snow_freshwater_flux) + ℐꜜˡʷ = JRA55FieldTimeSeries(:downwelling_longwave_radiation) + ℐꜜˢʷ = JRA55FieldTimeSeries(:downwelling_shortwave_radiation) freshwater_flux = (rain = Fra, snow = Fsn) diff --git a/src/DataWrangling/JRA55/JRA55_prescribed_land.jl b/src/DataWrangling/JRA55/JRA55_prescribed_land.jl index 1a42dca63..f20672adb 100644 --- a/src/DataWrangling/JRA55/JRA55_prescribed_land.jl +++ b/src/DataWrangling/JRA55/JRA55_prescribed_land.jl @@ -2,34 +2,40 @@ using NumericalEarth.Lands: PrescribedLand export JRA55PrescribedLand -JRA55PrescribedLand(arch::Distributed, FT=Float32; kw...) = - JRA55PrescribedLand(child_architecture(arch), FT; kw...) +JRA55PrescribedLand(arch::Distributed; kw...) = + JRA55PrescribedLand(child_architecture(arch); kw...) """ - JRA55PrescribedLand([architecture = CPU(), FT = Float32]; + JRA55PrescribedLand([architecture = CPU()]; dataset = RepeatYearJRA55(), start_date = first_date(dataset, :river_freshwater_flux), end_date = last_date(dataset, :river_freshwater_flux), - backend = JRA55NetCDFBackend(10), + dir = download_JRA55_cache, + time_indices_in_memory = 10, time_indexing = Cyclical(), + region = nothing, other_kw...) Return a [`PrescribedLand`](@ref) representing JRA55 reanalysis land surface data -(river runoff and iceberg calving freshwater fluxes). +(river runoff and iceberg calving freshwater fluxes). """ -function JRA55PrescribedLand(architecture=CPU(), FT=Float32; +function JRA55PrescribedLand(architecture = CPU(); dataset = RepeatYearJRA55(), start_date = first_date(dataset, :river_freshwater_flux), end_date = last_date(dataset, :river_freshwater_flux), - backend = JRA55NetCDFBackend(10), + dir = download_JRA55_cache, + time_indices_in_memory = 10, time_indexing = Cyclical(), + region = nothing, other_kw...) - kw = (; time_indexing, backend, start_date, end_date, dataset) + kw = (; time_indexing, time_indices_in_memory) kw = merge(kw, other_kw) - Fri = JRA55FieldTimeSeries(:river_freshwater_flux, architecture, FT; kw...) - Fic = JRA55FieldTimeSeries(:iceberg_freshwater_flux, architecture, FT; kw...) + JRA55FieldTimeSeries(name) = FieldTimeSeries(Metadata(name; dataset, start_date, end_date, dir, region), architecture; kw...) + + Fri = JRA55FieldTimeSeries(:river_freshwater_flux) + Fic = JRA55FieldTimeSeries(:iceberg_freshwater_flux) freshwater_flux = (; rivers = Fri, icebergs = Fic) diff --git a/src/DataWrangling/WOA/WOA.jl b/src/DataWrangling/WOA/WOA.jl index 013b4371e..4fa709804 100644 --- a/src/DataWrangling/WOA/WOA.jl +++ b/src/DataWrangling/WOA/WOA.jl @@ -33,6 +33,8 @@ import NumericalEarth.DataWrangling: z_interfaces, longitude_interfaces, latitude_interfaces, + longitude_name, + latitude_name, is_three_dimensional, reversed_vertical_axis, inpainted_metadata_path, @@ -89,6 +91,8 @@ reversed_vertical_axis(::WOAClimatology) = true longitude_interfaces(::WOAClimatology) = (-180, 180) latitude_interfaces(::WOAClimatology) = (-90, 90) +longitude_name(::Metadata{<:WOAClimatology}) = "lon" +latitude_name(::Metadata{<:WOAClimatology}) = "lat" available_variables(::WOAClimatology) = WOA_variable_names """ diff --git a/src/DataWrangling/dataset_backend.jl b/src/DataWrangling/dataset_backend.jl new file mode 100644 index 000000000..fa1dac439 --- /dev/null +++ b/src/DataWrangling/dataset_backend.jl @@ -0,0 +1,100 @@ +using Oceananigans.OutputReaders: Cyclical, AbstractInMemoryBackend, FlavorOfFTS, time_indices + +import Oceananigans.OutputReaders: new_backend +import Oceananigans.Fields: set! + +@inline instantiate(T::DataType) = T() +@inline instantiate(T) = T + +""" + DatasetBackend{N, C, I, M} <: AbstractInMemoryBackend{Int} + +In-memory backend for a `FieldTimeSeries` backed by a dataset whose metadata +maps each in-memory time index to a file (or subset of a file) on disk. The +backend carries + +- `start`, `length`: sliding-window extents into the metadata +- `inpainting`: inpainting algorithm used when reading per-file datasets + (e.g. `NearestNeighborInpainting`); `nothing` for datasets whose native + NetCDF already covers the whole target grid (e.g. JRA55). +- `metadata`: the dataset metadata — its dataset type parameterises `set!` + dispatch, so per-file and chunked (multi-time-per-file) datasets can + coexist under the same backend. + +Type parameters `N` (on-native-grid) and `C` (cache-inpainted-data) are +flags hoisted into the type so that dispatch and `Adapt.adapt_structure` +can act on them without allocating. +""" +struct DatasetBackend{N, C, I, M} <: AbstractInMemoryBackend{Int} + start :: Int + length :: Int + inpainting :: I + metadata :: M + + function DatasetBackend{N, C}(start::Int, length::Int, inpainting, metadata) where {N, C} + M = typeof(metadata) + I = typeof(inpainting) + return new{N, C, I, M}(start, length, inpainting, metadata) + end +end + +Adapt.adapt_structure(to, b::DatasetBackend{N, C}) where {N, C} = + DatasetBackend{N, C}(b.start, b.length, nothing, nothing) + +""" + DatasetBackend(length, metadata; + on_native_grid = false, + cache_inpainted_data = false, + inpainting = NearestNeighborInpainting(Inf)) + DatasetBackend(start, length, metadata; ...) + +Construct a `DatasetBackend` holding `length` in-memory time indices starting +at `start` (default `1`). `inpainting = nothing` selects the dispatch path +for datasets whose files hold multiple time instances (e.g. chunked NetCDF +like JRA55), where per-file inpainting is not applicable. +""" +function DatasetBackend(length, metadata; + on_native_grid = false, + cache_inpainted_data = false, + inpainting = NearestNeighborInpainting(Inf)) + + return DatasetBackend{on_native_grid, cache_inpainted_data}(1, length, inpainting, metadata) +end + +function DatasetBackend(start::Integer, length::Integer, metadata; + on_native_grid = false, + cache_inpainted_data = false, + inpainting = NearestNeighborInpainting(Inf)) + + return DatasetBackend{on_native_grid, cache_inpainted_data}(start, length, inpainting, metadata) +end + +Base.length(backend::DatasetBackend) = backend.length +Base.summary(backend::DatasetBackend) = string("DatasetBackend(", backend.start, ", ", backend.length, ")") + +new_backend(b::DatasetBackend{native, cache_data}, start, length) where {native, cache_data} = + DatasetBackend{native, cache_data}(start, length, b.inpainting, b.metadata) + +on_native_grid(::DatasetBackend{native}) where native = native +cache_inpainted_data(::DatasetBackend{native, cache_data}) where {native, cache_data} = cache_data + +const DatasetFieldTimeSeries{N} = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{N}} where N + +# Default per-file set! — each metadata index corresponds to its own file. +# Used by datasets whose native files hold one time instance per file, such +# as ECCO4 / EN4 / WOA monthly climatologies. Datasets whose files hold +# multiple time instances (e.g. JRA55) dispatch on a more specific backend +# signature keyed off the metadata's dataset type. +function set!(fts::DatasetFieldTimeSeries, backend=fts.backend) + inpainting = backend.inpainting + cache_data = cache_inpainted_data(backend) + + for t in time_indices(fts) + metadatum = @inbounds backend.metadata[t] + set!(fts[t], metadatum; inpainting, cache_inpainted_data=cache_data) + end + + fill_halo_regions!(fts) + + return nothing +end diff --git a/src/DataWrangling/metadata.jl b/src/DataWrangling/metadata.jl index a6f3fb092..b26caffd2 100644 --- a/src/DataWrangling/metadata.jl +++ b/src/DataWrangling/metadata.jl @@ -86,6 +86,11 @@ Metadata(name, dataset, dates, region, dir) = Metadata(name, dataset, dates, reg is_three_dimensional(::Metadata) = true z_interfaces(md::Metadata) = z_interfaces(md.dataset) + +# NetCDF coordinate-variable names. Default follows CF standard; datasets +# whose files use different names (e.g. JRA55 uses `lon`/`lat`) override. +longitude_name(::Metadata) = "longitude" +latitude_name(::Metadata) = "latitude" longitude_interfaces(md::Metadata) = longitude_interfaces(md.dataset) latitude_interfaces(md::Metadata) = latitude_interfaces(md.dataset) diff --git a/src/DataWrangling/metadata_field.jl b/src/DataWrangling/metadata_field.jl index d5656f320..004095d3a 100644 --- a/src/DataWrangling/metadata_field.jl +++ b/src/DataWrangling/metadata_field.jl @@ -2,9 +2,8 @@ using NCDatasets using JLD2 using NumericalEarth.InitialConditions: interpolate! using Statistics: median -using Oceananigans.Grids: λnodes, φnodes +using Oceananigans.Grids: λnodes, φnodes, Periodic, Bounded using Oceananigans.Architectures: on_architecture -using Oceananigans.Fields: fractional_x_index, fractional_y_index import Oceananigans.Fields: set!, Field, location @@ -23,19 +22,36 @@ restrict_location((LX, LY, LZ), ::Column) = (Nothing, Nothing, LZ) ##### restrict(::Nothing, interfaces, N) = interfaces, N +restrict(::Nothing, interfaces::NTuple{2,Any}, N) = interfaces, N +restrict(::Nothing, interfaces::AbstractVector, N) = interfaces, N + +# Snap bbox outward to native cell faces so restricted centers land on native centers +function restrict(bbox_interfaces, interfaces::NTuple{2,Any}, N) + left, right = interfaces + Δ = (right - left) / N + i⁻ = max(floor(Int, (bbox_interfaces[1] - left) / Δ), 0) + i⁺ = min(ceil( Int, (bbox_interfaces[2] - left) / Δ), N) + if i⁺ <= i⁻ + i⁺ = min(i⁻ + 1, N) + i⁻ = max(i⁺ - 1, 0) + end + return (left + i⁻ * Δ, left + i⁺ * Δ), i⁺ - i⁻ +end -# TODO support stretched native grids -function restrict(bbox_interfaces, interfaces, N) - LΔ = interfaces[2] - interfaces[1] - Δ = LΔ / N - grid_interfaces = (bbox_interfaces[1] - Δ/2, - bbox_interfaces[2] + Δ/2) +# Stretched native grid: snap outward to the nearest native cell interfaces. +function restrict(bbox_interfaces, interfaces::AbstractVector, N) + i⁻ = max(searchsortedlast(interfaces, bbox_interfaces[1]), 1) + i⁺ = min(searchsortedfirst(interfaces, bbox_interfaces[2]), length(interfaces)) + rN = max(i⁺ - i⁻, 1) + return interfaces[i⁻:i⁺], rN +end - rΔ = grid_interfaces[2] - grid_interfaces[1] - ϵ = rΔ / LΔ - rN = ceil(Int, ϵ * N) # Round up to ensure bounding box is covered +native_convention_longitude(::Nothing, native) = nothing - return grid_interfaces, rN +# Map a bbox longitude into the native longitude convention +function native_convention_longitude(bbox_longitude, native) + λ⁻ = convert_to_λ₀_λ₀_plus360(bbox_longitude[1], native[1]) + return (λ⁻, λ⁻ + (bbox_longitude[2] - bbox_longitude[1])) end """ @@ -48,53 +64,65 @@ and a column `RectilinearGrid` for `Column` regions. native_grid(metadata::Metadata, arch=CPU(); halo=(3, 3, 3)) = construct_native_grid(metadata, metadata.region, arch; halo) -# Full global grid (no region restriction) +# 2D-only datasets (surface forcing like JRA55) skip the z dimension. function construct_native_grid(metadata, ::Nothing, arch; halo) - Nx, Ny, Nz, _ = size(metadata) - z = z_interfaces(metadata) FT = eltype(metadata) longitude = longitude_interfaces(metadata) latitude = latitude_interfaces(metadata) + Nx, Ny, Nz = size(metadata) - grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), - halo, longitude, latitude, z) - return grid + if is_three_dimensional(metadata) + z = z_interfaces(metadata) + return LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), + halo, longitude, latitude, z) + else + return LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny), + halo = halo[1:2], longitude, latitude, + topology = (Periodic, Bounded, Flat)) + end end -# BoundingBox-restricted LatitudeLongitudeGrid function construct_native_grid(metadata, bbox::BoundingBox, arch; halo) - Nx, Ny, Nz, _ = size(metadata) - z = z_interfaces(metadata) FT = eltype(metadata) - longitude = longitude_interfaces(metadata) - latitude = latitude_interfaces(metadata) + native_longitude = longitude_interfaces(metadata) + native_latitude = latitude_interfaces(metadata) + + # Map the bbox into the native longitude convention. + bbox_lon = native_convention_longitude(bbox.longitude, native_longitude) - # TODO: can we restrict in `z` as well? - longitude, Nx = restrict(bbox.longitude, longitude, Nx) - latitude, Ny = restrict(bbox.latitude, latitude, Ny) + Nx, Ny, Nz = size(metadata) + longitude, Nx = restrict(bbox_lon, native_longitude, Nx) + latitude, Ny = restrict(bbox.latitude, native_latitude, Ny) - # Clamp halo so it does not exceed grid size in any dimension - halo = min.(halo, (Nx, Ny, Nz)) + TX = infer_longitudinal_topology(native_longitude, longitude) - grid = LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), - halo, longitude, latitude, z) - return grid + if is_three_dimensional(metadata) + z = z_interfaces(metadata) + return LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny, Nz), + halo, longitude, latitude, z, + topology = (TX, Bounded, Bounded)) + else + return LatitudeLongitudeGrid(arch, FT; size = (Nx, Ny), + halo = halo[1:2], longitude, latitude, + topology = (TX, Bounded, Flat)) + end end -# Column RectilinearGrid +# 2D-only datasets collapse to (Flat, Flat, Flat); 3D keep z Bounded. function construct_native_grid(metadata, col::Column, arch; halo) - _, _, Nz, _ = size(metadata) - z = z_interfaces(metadata) FT = eltype(metadata) + x = FT(col.longitude) + y = FT(col.latitude) - grid = RectilinearGrid(arch, FT; - size = Nz, - x = FT(col.longitude), - y = FT(col.latitude), - z, - halo = halo[3], - topology = (Flat, Flat, Bounded)) - return grid + if is_three_dimensional(metadata) + _, _, Nz, _ = size(metadata) + z = z_interfaces(metadata) + return RectilinearGrid(arch, FT; size = Nz, halo = halo[3], + x, y, z, topology = (Flat, Flat, Bounded)) + else + return RectilinearGrid(arch, FT; size = (), halo = (), + x, y, topology = (Flat, Flat, Flat)) + end end """ @@ -136,7 +164,7 @@ end architecture = CPU(), inpainting = default_inpainting(metadata), mask = nothing, - halo = (7, 7, 7), + halo = (3, 3, 3), cache_inpainted_data = true) Return a `Field` on `architecture` described by `metadata` with `halo` size. @@ -153,12 +181,12 @@ function Field(metadata::Metadatum, arch=CPU(); download_dataset(metadata) - # Column regions need special handling: the downloaded file may contain - # more data than a single column (e.g. CopernicusMarine returns a small - # grid around the point). Load onto an intermediate grid from the file's - # actual dimensions, then extract the column. + # Inpainting on a (Flat, Flat, *) column field is meaningless and the + # iterative algorithm doesn't terminate gracefully without horizontal + # neighbours; the NaN-aware bracket-blend in `set_region_data!` handles + # land cells directly. if metadata.region isa Column - return column_field_from_file(metadata, arch; inpainting, mask, halo, cache_inpainted_data) + inpainting = nothing end grid = native_grid(metadata, arch; halo) @@ -168,24 +196,23 @@ function Field(metadata::Metadatum, arch=CPU(); if !isnothing(inpainting) inpainted_path = inpainted_metadata_path(metadata) if isfile(inpainted_path) - file = jldopen(inpainted_path, "r") - maxiter = file["inpainting_maxiter"] - - # read data if generated with the same inpainting - if maxiter == inpainting.maxiter - data = file["data"] - close(file) - try - copyto!(parent(field), data) - return field - catch - @warn "Could not load existing inpainted data at $inpainted_path.\n" * - "Re-inpainting and saving data..." - rm(inpainted_path, force=true) + # apply a load guard for corrupted files + loaded = false + try + jldopen(inpainted_path, "r") do file + if haskey(file, "inpainting_maxiter") && + file["inpainting_maxiter"] == inpainting.maxiter + copyto!(parent(field), file["data"]) + loaded = true + end end + catch err + @warn "Could not load existing inpainted data at $inpainted_path; " * + "re-inpainting and saving data..." exception=err + rm(inpainted_path, force=true) + loaded = false end - - close(file) + loaded && return field end end @@ -255,204 +282,22 @@ function set!(target_field::Field, metadata::Metadatum; kw...) return target_field end -##### -##### Column field construction -##### - -function column_field_from_file(metadata, arch; halo=(3, 3, 3), kw...) - column_grid = native_grid(metadata, arch; halo) - - # Read the file's actual dimensions to build a matching intermediate grid - path = metadata_path(metadata) - ds = Dataset(path) - varname = dataset_variable_name(metadata) - var = ds[varname] - data_size = size(var) - Nx_file, Ny_file = data_size[1], data_size[2] - - # Read coordinate arrays - lon_dimname = NCDatasets.dimnames(var)[1] - lat_dimname = NCDatasets.dimnames(var)[2] - λ = haskey(ds, lon_dimname) ? ds[lon_dimname][:] : ds["longitude"][:] - φ = haskey(ds, lat_dimname) ? ds[lat_dimname][:] : ds["latitude"][:] - close(ds) - - if reversed_latitude_axis(metadata.dataset) - reverse!(φ) - end - - _, _, Nz, _ = size(metadata) - - # Validate that the cached file's vertical extent matches the dataset - # configuration. A common cause of mismatch is a stale cache from a previous - # run with a different vertical configuration (e.g. ERA5 `pressure_levels`). - if is_three_dimensional(metadata) && length(data_size) >= 3 && data_size[3] != Nz - error("Cached file $(path) has $(data_size[3]) vertical levels, but the " * - "dataset configuration expects $Nz. This is most likely a stale " * - "cache from a previous run with a different vertical configuration. " * - "Delete the file and re-run.") - end - z = z_interfaces(metadata) - FT = eltype(metadata) - - # Build cell interfaces from centers - Δλ = Nx_file > 1 ? λ[2] - λ[1] : FT(1) - λf = range(λ[1] - Δλ/2, stop = λ[end] + Δλ/2, length = Nx_file + 1) - - Δφ = Ny_file > 1 ? φ[2] - φ[1] : FT(1) - φf = range(φ[1] - Δφ/2, stop = φ[end] + Δφ/2, length = Ny_file + 1) - - halo = min.(halo, (Nx_file, Ny_file, Nz)) - - intermediate_grid = LatitudeLongitudeGrid(arch, FT; - size = (Nx_file, Ny_file, Nz), - halo, longitude = λf, latitude = φf, z) - - # Load data onto intermediate grid (no inpainting — columns have no horizontal neighbors) - LX, LY, LZ = dataset_location(metadata.dataset, metadata.name) - intermediate_field = Field{LX, LY, LZ}(intermediate_grid) - - data = retrieve_data(metadata) - set_metadata_field!(intermediate_field, data, metadata) - fill_halo_regions!(intermediate_field) - - # Extract column - _, _, LZ_col = location(metadata) - col_field = Field{Nothing, Nothing, LZ_col}(column_grid) - extract_column!(col_field, intermediate_field, metadata.region) - - return col_field -end - -##### -##### Column extraction utilities -##### - -# Dispatch extraction on interpolation method -function extract_column!(column_field, intermediate_field, col::Column) - extract_column!(column_field, intermediate_field, col, col.interpolation) -end - -function extract_column!(column_field, intermediate_field, col, ::Linear) - grid = intermediate_field.grid - arch = architecture(grid) - LX, LY, LZ = Oceananigans.Fields.location(intermediate_field) - locs = (LX(), LY(), LZ()) - - # Fractional indices (1-based, continuous) - fi = fractional_x_index(col.longitude, locs, grid) - fj = fractional_y_index(col.latitude, locs, grid) - - # Lower-left index and weights - i₁ = clamp(floor(Int, fi), 1, size(grid, 1)) - j₁ = clamp(floor(Int, fj), 1, size(grid, 2)) - i₂ = clamp(i₁ + 1, 1, size(grid, 1)) - j₂ = clamp(j₁ + 1, 1, size(grid, 2)) - - wx = clamp(fi - floor(fi), 0, 1) - wy = clamp(fj - floor(fj), 0, 1) - - launch!(arch, column_field.grid, :z, _bilinear_interpolate_column!, - column_field, intermediate_field, i₁, j₁, i₂, j₂, wx, wy) - - return nothing -end - -@kernel function _bilinear_interpolate_column!(column_field, source, i₁, j₁, i₂, j₂, wx, wy) - k = @index(Global, Linear) - @inbounds begin - v00 = source[i₁, j₁, k] - v10 = source[i₂, j₁, k] - v01 = source[i₁, j₂, k] - v11 = source[i₂, j₂, k] - column_field[1, 1, k] = (1 - wx) * (1 - wy) * v00 + - wx * (1 - wy) * v10 + - (1 - wx) * wy * v01 + - wx * wy * v11 - end -end - -function extract_column!(column_field, intermediate_field, col, ::Nearest) - grid = intermediate_field.grid - arch = architecture(grid) - LX, LY, LZ = Oceananigans.Fields.location(intermediate_field) - locs = (LX(), LY(), LZ()) # fractional index functions expect instances, not types - - # Use Oceananigans' fractional index machinery (handles cyclic longitude etc.) - i★ = round(Int, fractional_x_index(col.longitude, locs, grid)) - j★ = round(Int, fractional_y_index(col.latitude, locs, grid)) - - launch!(arch, column_field.grid, :z, copy_column!, column_field, intermediate_field, i★, j★) - - return nothing -end - -@kernel function copy_column!(column_field, source_field, i★, j★) - k = @index(Global, Linear) - @inbounds column_field[1, 1, k] = source_field[i★, j★, k] -end - -# manglings -struct ShiftSouth end -struct AverageNorthSouth end - -@inline mangle(i, j, data, ::Nothing) = @inbounds data[i, j] -@inline mangle(i, j, data, ::ShiftSouth) = @inbounds data[i, j-1] -@inline mangle(i, j, data, ::AverageNorthSouth) = @inbounds (data[i, j+1] + data[i, j]) / 2 - -@inline mangle(i, j, k, data, ::Nothing) = @inbounds data[i, j, k] -@inline mangle(i, j, k, data, ::ShiftSouth) = @inbounds data[i, j-1, k] -@inline mangle(i, j, k, data, ::AverageNorthSouth) = @inbounds (data[i, j+1, k] + data[i, j, k]) / 2 - function set_metadata_field!(field, data, metadatum) - grid = field.grid - arch = architecture(grid) - - Nx, Ny, Nz = size(metadatum) - - mangling = if size(data, 2) == Ny-1 - @debug "Shifting field southward" - ShiftSouth() - elseif size(data, 2) == Ny+1 - @debug "Averaging field in north-south dir" - AverageNorthSouth() - else - nothing - end - - conversion = conversion_units(metadatum) - - if ndims(data) == 2 - _kernel = _set_2d_metadata_field! - spec = :xy - else - _kernel = _set_3d_metadata_field! - spec = :xyz - end - - data = on_architecture(arch, data) - Oceananigans.Utils.launch!(arch, grid, spec, _kernel, field, data, mangling, conversion) - + full_data = ndims(data) == 2 ? reshape(data, size(data, 1), size(data, 2), 1) : data + λc, φc = read_file_coords(metadatum) + set_region_data!(field, full_data, λc, φc, metadatum) return nothing end -@kernel function _set_2d_metadata_field!(field, data, mangling, conversion) - i, j = @index(Global, NTuple) - FT = eltype(field) - d = mangle(i, j, data, mangling) - d = nan_convert_missing(FT, d) - d = convert_units(d, conversion) - @inbounds field[i, j, 1] = d -end - -@kernel function _set_3d_metadata_field!(field, data, mangling, conversion) - i, j, k = @index(Global, NTuple) - FT = eltype(field) - d = mangle(i, j, k, data, mangling) - d = nan_convert_missing(FT, d) - d = convert_units(d, conversion) - - @inbounds field[i, j, k] = d +# Read the lon/lat cell centres from the NetCDF file using the names supplied +# by the dataset's `longitude_name` / `latitude_name` traits. +function read_file_coords(metadatum) + ds = Dataset(metadata_path(metadatum)) + λc = ds[longitude_name(metadatum)][:] + φc = ds[latitude_name(metadatum)][:] + close(ds) + reversed_latitude_axis(metadatum.dataset) && reverse!(φc) + return λc, φc end ##### diff --git a/src/DataWrangling/metadata_field_time_series.jl b/src/DataWrangling/metadata_field_time_series.jl index 142eb6b07..42aa50c5a 100644 --- a/src/DataWrangling/metadata_field_time_series.jl +++ b/src/DataWrangling/metadata_field_time_series.jl @@ -1,73 +1,4 @@ -using Oceananigans.Architectures: AbstractArchitecture -using Oceananigans.Grids: AbstractGrid -using Oceananigans.Fields: interpolate! -using Oceananigans.OutputReaders: Cyclical, AbstractInMemoryBackend, FlavorOfFTS, time_indices - -import Oceananigans.OutputReaders: new_backend, update_field_time_series!, FieldTimeSeries - -@inline instantiate(T::DataType) = T() -@inline instantiate(T) = T - -struct DatasetBackend{N, C, I, M} <: AbstractInMemoryBackend{Int} - start :: Int - length :: Int - inpainting :: I - metadata :: M - - function DatasetBackend{N, C}(start::Int, length::Int, inpainting, metadata) where {N, C} - M = typeof(metadata) - I = typeof(inpainting) - return new{N, C, I, M}(start, length, inpainting, metadata) - end -end - -Adapt.adapt_structure(to, b::DatasetBackend{N, C}) where {N, C} = - DatasetBackend{N, C}(b.start, b.length, nothing, nothing) - -""" - DatasetBackend(length, metadata; - on_native_grid = false, - cache_inpainted_data = false, - inpainting = NearestNeighborInpainting(Inf)) - -Represent a FieldTimeSeries backed by the backend that corresponds to the -dataset with `metadata` (e.g., netCDF). Each time instance is stored in an -individual file. -""" -function DatasetBackend(length, metadata; - on_native_grid = false, - cache_inpainted_data = false, - inpainting = NearestNeighborInpainting(Inf)) - - return DatasetBackend{on_native_grid, cache_inpainted_data}(1, length, inpainting, metadata) -end - -Base.length(backend::DatasetBackend) = backend.length -Base.summary(backend::DatasetBackend) = string("DatasetBackend(", backend.start, ", ", backend.length, ")") - -new_backend(b::DatasetBackend{native, cache_data}, start, length) where {native, cache_data} = - DatasetBackend{native, cache_data}(start, length, b.inpainting, b.metadata) - -on_native_grid(::DatasetBackend{native}) where native = native -cache_inpainted_data(::DatasetBackend{native, cache_data}) where {native, cache_data} = cache_data - -const DatasetFieldTimeSeries{N} = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:DatasetBackend{N}} where N - -function set!(fts::DatasetFieldTimeSeries) - backend = fts.backend - inpainting = backend.inpainting - cache_data = cache_inpainted_data(backend) - - for t in time_indices(fts) - # Set each element of the time-series to the associated file - metadatum = @inbounds backend.metadata[t] - set!(fts[t], metadatum; inpainting, cache_inpainted_data=cache_data) - end - - fill_halo_regions!(fts) - - return nothing -end +import Oceananigans.OutputReaders: update_field_time_series!, FieldTimeSeries """ FieldTimeSeries(metadata::Metadata [, arch_or_grid=CPU() ]; @@ -112,16 +43,22 @@ function FieldTimeSeries(metadata::Metadata, grid::AbstractGrid; inpainting = default_inpainting(metadata), cache_inpainted_data = true) - # Make sure all the required individual files are downloaded download_dataset(metadata) + times = native_times(metadata) + + # Make sure we do not use more indices then the ones available! + if length(times) < time_indices_in_memory + time_indices_in_memory = length(times) + end + inpainting isa Int && (inpainting = NearestNeighborInpainting(inpainting)) - is_native = grid == native_grid(metadata) - backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid=is_native, inpainting, cache_inpainted_data) + on_native_grid = grid == native_grid(metadata, architecture(grid)) + backend = DatasetBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data) - times = native_times(metadata) loc = LX, LY, LZ = location(metadata) boundary_conditions = FieldBoundaryConditions(grid, instantiate.(loc)) + fts = FieldTimeSeries{LX, LY, LZ}(grid, times; backend, time_indexing, boundary_conditions) set!(fts) diff --git a/src/DataWrangling/set_region_data.jl b/src/DataWrangling/set_region_data.jl new file mode 100644 index 000000000..4eceb98bb --- /dev/null +++ b/src/DataWrangling/set_region_data.jl @@ -0,0 +1,219 @@ +using Oceananigans.Utils: launch! +using Oceananigans.Architectures: AbstractArchitecture, architecture +using Oceananigans.Grids: AbstractGrid, Periodic, Bounded, λnodes, φnodes +using Oceananigans.Fields: Field, interior, interpolate! +using Oceananigans.Fields: convert_to_λ₀_λ₀_plus360 +using GPUArraysCore: @allowscalar + +##### +##### Region helpers shared across dataset backends +##### + +function compute_bounding_nodes(grid, LH, hnodes) + hg = hnodes(grid, LH()) + h₁ = @allowscalar minimum(hg) + h₂ = @allowscalar maximum(hg) + return h₁, h₂ +end + +# `ε` forgives Float32 to Float64 promotion noise so the slice doesn't lose a +# cell at each end when grid centers are compared against file centers. +function compute_bounding_indices(bounds::Tuple, hc) + h₁, h₂ = bounds + Nh = length(hc) + ε = eps(Float32) * max(one(eltype(hc)), abs(h₁), abs(h₂)) + i₁ = max(searchsortedfirst(hc, h₁ - ε), 1) + i₂ = min( searchsortedlast(hc, h₂ + ε), Nh) + return i₁, i₂ +end + +# Periodic only when the restricted span equals the full native span. +function infer_longitudinal_topology(full_longitude, restricted_longitude) + full_span = full_longitude[end] - full_longitude[1] + restricted_span = restricted_longitude[end] - restricted_longitude[1] + return restricted_span ≈ full_span ? Periodic : Bounded +end + +##### +##### Mangling utilities +##### + +struct ShiftSouth end +struct AverageNorthSouth end + +# `mangle(i, j, k, data, mangling)` reads file `data` at metadata-grid index `(i, j, k)`, accounting +# for staggered lat-axis offsets. Used inside the region-aware kernel. +@inline mangle(i, j, k, data, ::Nothing) = @inbounds data[i, j, k] +@inline mangle(i, j, k, data, ::ShiftSouth) = @inbounds data[i, max(j - 1, 1), k] +@inline mangle(i, j, k, data, ::AverageNorthSouth) = @inbounds (data[i, j, k] + data[i, j + 1, k]) / 2 + +##### +##### Region-aware filling for Fields and FieldTimeSeries via a single kernel. +##### +##### `read_data(data, i, j, k, region_info, mangling)` is the only access point: it composes the +##### file-axis offset (region) with the lat-axis remap (mangling). All region/mangling combinations +##### go through one kernel that handles NaN + unit conversion in the same pass. +##### + +struct BBoxOffset + di :: Int + dj :: Int +end + +struct ColumnInfo{F, I} + i⁻ :: Int + i⁺ :: Int + j⁻ :: Int + j⁺ :: Int + wx :: F + wy :: F + ℑ :: I +end + +# `region_info` resolves the target's region to a kernel-friendly struct. +region_info(::Nothing, target, λc, φc) = nothing + +function region_info(::BoundingBox, target, λc, φc) + LX, LY, _ = Oceananigans.Fields.location(target) + λmin, λmax = compute_bounding_nodes(target.grid, LX, λnodes) + φmin, φmax = compute_bounding_nodes(target.grid, LY, φnodes) + + # Shift the target's longitude into the file's `[λc[1], λc[1]+360)` + if !isempty(λc) + λmin = convert_to_λ₀_λ₀_plus360(λmin, λc[1]) + λmax = convert_to_λ₀_λ₀_plus360(λmax, λc[1]) + end + + i₁, _ = compute_bounding_indices((λmin, λmax), λc) + j₁, _ = compute_bounding_indices((φmin, φmax), φc) + + Nx, Ny, _ = size(target) + di = clamp(i₁ - 1, 0, max(length(λc) - Nx, 0)) + dj = clamp(j₁ - 1, 0, max(length(φc) - Ny, 0)) + return BBoxOffset(di, dj) +end + +function region_info(col::Column, target, λc, φc) + i⁻, i⁺, wx = bracket_with_weight(λc, col.longitude; period = infer_longitudinal_period(λc)) + j⁻, j⁺, wy = bracket_with_weight(φc, col.latitude) # latitude is never cyclic + FT = eltype(target) + return ColumnInfo(i⁻, i⁺, j⁻, j⁺, FT(wx), FT(wy), col.interpolation) +end + +# 360 if `λc` spans the full globe (cyclic), else `nothing`. +function infer_longitudinal_period(λc) + length(λc) < 2 && return nothing + Δ = λc[2] - λc[1] + span = λc[end] - λc[1] + Δ + return span ≈ 360 ? 360 : nothing +end + +# Cyclic-aware bracketing. With `period`, the cell between `coords[end]` and `coords[1] + period` is the wrap cell: +# returns `(n, 1, w)` so the blend reads `data[n, …]` and `data[1, …]`. +function bracket_with_weight(coords, x; period = nothing) + n = length(coords) + + # Single-cell axis: nothing to bracket — point both corners at the only cell. + n ≤ 1 && return 1, 1, zero(x) + + if !isnothing(period) + x = coords[1] + mod(x - coords[1], period) + if x > coords[end] + Δ = (coords[1] + period) - coords[end] + w = (x - coords[end]) / Δ + return n, 1, clamp(w, 0, 1) + end + end + + i⁺ = searchsortedfirst(coords, x) + i⁺ = clamp(i⁺, 2, n) + i⁻ = i⁺ - 1 + Δ = coords[i⁺] - coords[i⁻] + w = Δ == 0 ? zero(x) : (x - coords[i⁻]) / Δ + return i⁻, i⁺, clamp(w, 0, 1) +end + +# `mangling_for` detects a file/grid lat-axis offset from the data shape. +function mangling_for(metadata, data_lat_count) + Ny = size(metadata)[2] + return data_lat_count == Ny - 1 ? ShiftSouth() : + data_lat_count == Ny + 1 ? AverageNorthSouth() : + nothing +end + +# `read_data(data, i, j, k, region, mangling, FT)` returns the file value at +# the grid's (i, j, k) as `FT`, with `Missing` converted to NaN. +@inline read_data(data, i, j, k, ::Nothing, mangling, FT) = nan_convert_missing(FT, mangle(i, j, k, data, mangling)) +@inline read_data(data, i, j, k, b::BBoxOffset, mangling, FT) = nan_convert_missing(FT, mangle(i + b.di, j + b.dj, k, data, mangling)) +@inline read_data(data, _, _, k, c::ColumnInfo, mangling, FT) = blend(data, c, k, mangling, c.ℑ, FT) + +# NaN-aware bilinear blend: drop NaN corners and renormalise weights. +# Returns NaN only when all four corners are land. +@inline function blend(data, c, k, mangling, ::Linear, FT) + d00 = nan_convert_missing(FT, mangle(c.i⁻, c.j⁻, k, data, mangling)) + d10 = nan_convert_missing(FT, mangle(c.i⁺, c.j⁻, k, data, mangling)) + d01 = nan_convert_missing(FT, mangle(c.i⁻, c.j⁺, k, data, mangling)) + d11 = nan_convert_missing(FT, mangle(c.i⁺, c.j⁺, k, data, mangling)) + w00 = (1 - c.wx) * (1 - c.wy) * !isnan(d00) + w10 = c.wx * (1 - c.wy) * !isnan(d10) + w01 = (1 - c.wx) * c.wy * !isnan(d01) + w11 = c.wx * c.wy * !isnan(d11) + Σw = w00 + w10 + w01 + w11 + Σw == 0 && return convert(FT, NaN) + return (w00 * ifelse(isnan(d00), zero(FT), d00) + + w10 * ifelse(isnan(d10), zero(FT), d10) + + w01 * ifelse(isnan(d01), zero(FT), d01) + + w11 * ifelse(isnan(d11), zero(FT), d11)) / Σw +end + +@inline function blend(data, c, k, mangling, ::Nearest, FT) + i = c.wx ≥ 0.5 ? c.i⁺ : c.i⁻ + j = c.wy ≥ 0.5 ? c.j⁺ : c.j⁻ + near = nan_convert_missing(FT, mangle(i, j, k, data, mangling)) + # If the closest corner is land, fall back to the NaN-aware Linear blend. + return isnan(near) ? blend(data, c, k, mangling, Linear(), FT) : near +end + +@kernel function _set_region_kernel!(dst, data, region, mangling, conversion, FT) + i, j, k = @index(Global, NTuple) + d = read_data(data, i, j, k, region, mangling, FT) + d = convert_units(d, conversion) + @inbounds dst[i, j, k] = d +end + +""" + set_region_data!(target, data, λc, φc, metadata) + +Fill the region of `target` (Field or FieldTimeSeries) implied by `metadata.region` from `data`, +applying mangling, NaN conversion, and unit conversion in a single GPU-friendly kernel pass. +""" +function set_region_data!(target::Field, data, λc, φc, metadata; + mangling = mangling_for(metadata, size(data, 2)), + conversion = conversion_units(metadata)) + + region = region_info(metadata.region, target, λc, φc) + FT = eltype(target) + grid = target.grid + arch = architecture(grid) + data = on_architecture(arch, data) + launch!(arch, grid, :xyz, _set_region_kernel!, interior(target), data, region, mangling, conversion, FT) + return nothing +end + +function set_region_data!(target::FieldTimeSeries, data, λc, φc, metadata; + mangling = mangling_for(metadata, size(data, 2)), + conversion = conversion_units(metadata), + slot_indices = 1:size(target, 4)) + + region = region_info(metadata.region, target, λc, φc) + grid = target.grid + arch = architecture(grid) + FT = eltype(target) + data = on_architecture(arch, data) + for (data_time, slot_time) in zip(axes(data, 4), slot_indices) + dest = view(interior(target), :, :, :, slot_time) + slice = view(data, :, :, :, data_time) + launch!(arch, grid, :xyz, _set_region_kernel!, dest, slice, region, mangling, conversion, FT) + end + return nothing +end diff --git a/src/NumericalEarth.jl b/src/NumericalEarth.jl index fc5731492..cf0c48328 100644 --- a/src/NumericalEarth.jl +++ b/src/NumericalEarth.jl @@ -32,7 +32,6 @@ export os_papa_prescribed_fluxes, os_papa_prescribed_flux_boundary_conditions, OSPapaHourly, - JRA55NetCDFBackend, regrid_bathymetry, Metadata, Metadatum, @@ -137,7 +136,6 @@ using NumericalEarth.DataWrangling.EN4 using NumericalEarth.DataWrangling.ORCA using NumericalEarth.DataWrangling.WOA using NumericalEarth.DataWrangling.JRA55 -using NumericalEarth.DataWrangling.JRA55: JRA55NetCDFBackend using NumericalEarth.DataWrangling.OSPapa using PrecompileTools: @setup_workload, @compile_workload diff --git a/test/download_utils.jl b/test/download_utils.jl index 23a98118e..469cba410 100644 --- a/test/download_utils.jl +++ b/test/download_utils.jl @@ -10,11 +10,13 @@ function emit_ci_warning(title, message) end function download_from_artifacts(filepath::AbstractString) - if !isfile(filepath) - filename = basename(filepath) - fallback_url = ARTIFACTS_BASE_URL * filename - @info "Downloading $filename from NumericalEarthArtifacts fallback..." - Downloads.download(fallback_url, filepath) + filename = basename(filepath) + fallback_url = ARTIFACTS_BASE_URL * filename + @info "Downloading $filename from NumericalEarthArtifacts fallback..." + mktemp(dirname(filepath)) do tmppath, tmpio + close(tmpio) + Downloads.download(fallback_url, tmppath) + mv(tmppath, filepath; force=true) end end diff --git a/test/runtests.jl b/test/runtests.jl index ef9d1bc50..e4dc876e2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -78,7 +78,7 @@ function __init__() ##### try - atmosphere = JRA55PrescribedAtmosphere(backend=JRA55NetCDFBackend(2)) + atmosphere = JRA55PrescribedAtmosphere(time_indices_in_memory=2) catch e @warn "Original JRA55 download failed, trying NumericalEarthArtifacts fallback..." exception=(e, catch_backtrace()) emit_ci_warning("Broken JRA55 download", "Original source failed during init") @@ -86,7 +86,7 @@ function __init__() datum = Metadatum(name; dataset=JRA55.RepeatYearJRA55()) download_from_artifacts(metadata_path(datum)) end - atmosphere = JRA55PrescribedAtmosphere(backend=JRA55NetCDFBackend(2)) + atmosphere = JRA55PrescribedAtmosphere(time_indices_in_memory=2) end ##### diff --git a/test/test_checkpointer.jl b/test/test_checkpointer.jl index 3d058af42..31cbb53f1 100644 --- a/test/test_checkpointer.jl +++ b/test/test_checkpointer.jl @@ -27,9 +27,8 @@ using Oceananigans.OutputWriters: Checkpointer set!(sea_ice.model, h=hi, ℵ=hi) # Create atmosphere, land, and radiation - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) - land = JRA55PrescribedLand(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) + land = JRA55PrescribedLand(arch; time_indices_in_memory=4) return OceanSeaIceModel(ocean, sea_ice; atmosphere, land) end diff --git a/test/test_column_field.jl b/test/test_column_field.jl index bb9004e15..27e99507d 100644 --- a/test/test_column_field.jl +++ b/test/test_column_field.jl @@ -3,8 +3,8 @@ include("runtests_setup.jl") using NumericalEarth.DataWrangling: Column, Linear, Nearest, BoundingBox, native_grid, restrict_location, dataset_location - -using NumericalEarth.DataWrangling: extract_column! +using NumericalEarth.DataWrangling: bracket_with_weight, infer_longitudinal_period, + region_info, blend, ColumnInfo using Oceananigans using Oceananigans.BoundaryConditions: fill_halo_regions! @@ -15,121 +15,103 @@ using Oceananigans.Grids: λnodes, φnodes, topology, Flat, Bounded, Periodic const test_longitude = 12.0 const test_latitude = -50.0 -@testset "extract_column! with Nearest interpolation" begin - for arch in test_architectures - A = typeof(arch) - @testset "Nearest extraction on $A" begin - # Create a LatitudeLongitudeGrid with spatially varying data - intermediate_grid = LatitudeLongitudeGrid(arch; - size = (4, 4, 2), - longitude = (0, 4), - latitude = (0, 4), - z = (-20, 0)) - - intermediate_field = CenterField(intermediate_grid) - - # Set distinct values at each horizontal point - for i in 1:4, j in 1:4, k in 1:2 - @allowscalar intermediate_field[i, j, k] = 10 * i + j + 0.1 * k - end - fill_halo_regions!(intermediate_field) - - # Column near grid point (3, 2) → lon≈2.5, lat≈1.5 - col = Column(2.5, 1.5; interpolation=Nearest()) - column_grid = RectilinearGrid(arch; - size = 2, - x = 2.5, - y = 1.5, - z = (-20, 0), - halo = 3, - topology = (Flat, Flat, Bounded)) +@testset "bracket_with_weight (non-cyclic)" begin + coords = [0.5, 1.5, 2.5, 3.5] + + # Interior point, exact midpoint between cells. + i⁻, i⁺, w = bracket_with_weight(coords, 2.0) + @test (i⁻, i⁺) == (2, 3) + @test w ≈ 0.5 + + # Off-grid below: clamps to first interval. + i⁻, i⁺, w = bracket_with_weight(coords, -1.0) + @test (i⁻, i⁺) == (1, 2) + @test w == 0.0 + + # Off-grid above: clamps to last interval. + i⁻, i⁺, w = bracket_with_weight(coords, 5.0) + @test (i⁻, i⁺) == (3, 4) + @test w == 1.0 + + # On the right-most centre: weight = 1. + i⁻, i⁺, w = bracket_with_weight(coords, 3.5) + @test (i⁻, i⁺) == (3, 4) + @test w ≈ 1.0 + + # Single-cell axis: nothing to bracket; both corners point at the only cell. + # GLORYS via CopernicusMarine returns 1-cell-wide chunked files for Column queries. + i⁻, i⁺, w = bracket_with_weight([7.5], 7.5) + @test (i⁻, i⁺, w) == (1, 1, 0.0) + i⁻, i⁺, w = bracket_with_weight([7.5], 99.0) + @test (i⁻, i⁺, w) == (1, 1, 0.0) +end - column_field = Field{Nothing, Nothing, Center}(column_grid) +@testset "bracket_with_weight (cyclic wrap)" begin + coords = collect(0.5:1.0:359.5) # global 1° centres + n = length(coords) - extract_column!(column_field, intermediate_field, col, Nearest()) + # Interior point — period is a no-op there. + i⁻, i⁺, w = bracket_with_weight(coords, 180.0; period = 360) + @test (i⁻, i⁺) == (180, 181) + @test w ≈ 0.5 - # Find expected nearest indices - λnodes_arr = λnodes(intermediate_grid, Center(); with_halos=false) - φnodes_arr = φnodes(intermediate_grid, Center(); with_halos=false) - i★ = argmin(abs.(λnodes_arr .- 2.5)) - j★ = argmin(abs.(φnodes_arr .- 1.5)) + # Wrap cell: x just below the period boundary. + i⁻, i⁺, w = bracket_with_weight(coords, 359.99; period = 360) + @test (i⁻, i⁺) == (n, 1) + @test 0 < w < 1 - @allowscalar begin - for k in 1:2 - @test column_field[1, 1, k] == intermediate_field[i★, j★, k] - end - end - end + # x past the period: mod wraps it back into the regular range. + i⁻, i⁺, w = bracket_with_weight(coords, 360.5; period = 360) + @test i⁻ == 1 + @test i⁺ == 2 || (i⁻ == n && i⁺ == 1) # right at coords[1] boundary +end - @testset "Nearest extraction preserves vertical profile on $A" begin - intermediate_grid = LatitudeLongitudeGrid(arch; - size = (3, 3, 5), - longitude = (10, 13), - latitude = (40, 43), - z = (-50, 0)) +@testset "infer_longitudinal_period" begin + @test infer_longitudinal_period(collect(0.5:1.0:359.5)) == 360 + @test infer_longitudinal_period(collect(-179.75:0.5:179.75)) == 360 + @test infer_longitudinal_period([10.0, 11.0, 12.0]) === nothing + @test infer_longitudinal_period([100.0]) === nothing +end - intermediate_field = CenterField(intermediate_grid) +@testset "NaN-aware blend" begin + # 2x2x1 synthetic data; column at the centre point with equal weights. + c = ColumnInfo(1, 2, 1, 2, 0.5f0, 0.5f0, Linear()) + FT = Float32 - # Set a vertical profile: value = depth level - for k in 1:5 - interior(intermediate_field)[:, :, k] .= Float64(k) - end - fill_halo_regions!(intermediate_field) + # All-valid: result is the simple average. + data_full = reshape(Float32[1 2; 3 4], 2, 2, 1) + @test blend(data_full, c, 1, nothing, c.ℑ, FT) ≈ 2.5f0 - col = Column(11.5, 41.5; interpolation=Nearest()) - column_grid = RectilinearGrid(arch; - size = 5, - x = 11.5, - y = 41.5, - z = (-50, 0), - halo = 3, - topology = (Flat, Flat, Bounded)) + # All-NaN: result is NaN. + data_nan = fill(NaN32, 2, 2, 1) + @test isnan(blend(data_nan, c, 1, nothing, c.ℑ, FT)) - column_field = Field{Nothing, Nothing, Center}(column_grid) - extract_column!(column_field, intermediate_field, col, Nearest()) + # Partial: bottom-right corner is NaN, weights renormalise over the rest. + # Weights become (0.25, 0.25, 0.25, 0); Σw = 0.75; sum = 1+2+3 = 6; + # result = 6 / 0.75 = 2.0 in 1/2/3 → renormalised mean. + data_part = reshape(Float32[1 2; 3 NaN32], 2, 2, 1) + @test blend(data_part, c, 1, nothing, c.ℑ, FT) ≈ 2.0f0 - @allowscalar begin - for k in 1:5 - @test column_field[1, 1, k] == k - end - end - end - end + # Missing values from NetCDF-style arrays are treated as NaN. + data_missing = reshape(Union{Missing, Float32}[1.0 2.0; 3.0 missing], 2, 2, 1) + @test blend(data_missing, c, 1, nothing, c.ℑ, FT) ≈ 2.0f0 end -@testset "extract_column! dispatch routes on interpolation type" begin - for arch in test_architectures - A = typeof(arch) - @testset "Dispatch on $A" begin - intermediate_grid = LatitudeLongitudeGrid(arch; - size = (4, 4, 2), - longitude = (0, 4), - latitude = (0, 4), - z = (-20, 0)) - - intermediate_field = CenterField(intermediate_grid) - interior(intermediate_field) .= 42.0 - fill_halo_regions!(intermediate_field) - - column_grid = RectilinearGrid(arch; - size = 2, - x = 2.0, - y = 2.0, - z = (-20, 0), - halo = 3, - topology = (Flat, Flat, Bounded)) - - # Column dispatch routes to the correct method - col_nearest = Column(2.0, 2.0; interpolation=Nearest()) - cf = Field{Nothing, Nothing, Center}(column_grid) - extract_column!(cf, intermediate_field, col_nearest) +@testset "blend dispatches Linear vs Nearest" begin + data = reshape(Float32[1 2; 3 4], 2, 2, 1) + FT = Float32 - @allowscalar begin - @test cf[1, 1, 1] == 42.0 - @test cf[1, 1, 2] == 42.0 - end - end - end + # wx = wy = 0.5 → average for Linear; arbitrary corner for Nearest. + c_lin = ColumnInfo(1, 2, 1, 2, 0.5f0, 0.5f0, Linear()) + @test blend(data, c_lin, 1, nothing, c_lin.ℑ, FT) ≈ 2.5f0 + + # wx = 0.7, wy = 0.7 → both above 0.5 → picks i⁺, j⁺ = data[2,2,1] = 4. + c_near = ColumnInfo(1, 2, 1, 2, 0.7f0, 0.7f0, Nearest()) + @test blend(data, c_near, 1, nothing, c_near.ℑ, FT) ≈ 4.0f0 + + # wx = 0.3, wy = 0.3 → both below 0.5 → picks i⁻, j⁻ = data[1,1,1] = 1. + c_near2 = ColumnInfo(1, 2, 1, 2, 0.3f0, 0.3f0, Nearest()) + @test blend(data, c_near2, 1, nothing, c_near2.ℑ, FT) ≈ 1.0f0 end @testset "End-to-end Column Field construction" begin @@ -217,8 +199,9 @@ end grid = native_grid(md) @test grid isa RectilinearGrid - @test topology(grid) == (Flat, Flat, Bounded) - # ERA5 has z = (0, 1), single level + # ERA5HourlySingleLevel is a 2D surface dataset — Column grids collapse + # the (absent) vertical axis to Flat as well. + @test topology(grid) == (Flat, Flat, Flat) @test size(grid) == (1, 1, 1) end @@ -235,32 +218,49 @@ end @testset "restrict (BoundingBox grid construction helper)" begin restrict = NumericalEarth.DataWrangling.restrict - # Identity case: bbox covers the full domain. Grid pads by Δ/2 on each side - # so that face midpoints land on data centers. The padded extent is N+1 - # cells of width Δ exactly, but Float64 rounding can push the ceil one cell - # past that — so allow [N+1, N+2]. - grid_interfaces, rN = restrict((0.0, 360.0), (0.0, 360.0), 1440) - @test grid_interfaces[1] ≈ -0.125 - @test grid_interfaces[2] ≈ 360.125 - @test 1441 <= rN <= 1442 - - # Half-domain bbox: rN should be just over half of N. - _, rN = restrict((0.0, 180.0), (0.0, 360.0), 1440) - @test 720 < rN <= 722 # ceil(0.5 * 1440 + small) = 721 - - # Small bbox (5° wide on a 1440-cell grid): rN should be ceil(20 + small) = 21. - grid_interfaces, rN = restrict((0.0, 5.0), (0.0, 360.0), 1440) - @test grid_interfaces[1] ≈ -0.125 - @test grid_interfaces[2] ≈ 5.125 + # Uniform: snap bbox outward to native faces (Δ = 0.25 throughout). + grid_interfaces, rN = restrict((0.0, 5.0), (-0.125, 359.875), 1440) + @test grid_interfaces == (-0.125, 5.125) @test rN == 21 - # Off-origin bbox preserves width: 5° wide on a 720-cell, 180°-tall grid → - # rΔ = 5° + Δ = 5.25°, rN = ceil((5.25/180) * 720) = 21. - _, rN_off = restrict((40.0, 45.0), (-90.0, 90.0), 720) - @test rN_off == 21 + # Already face-aligned — snap is a no-op. + grid_interfaces, rN = restrict((-0.125, 5.125), (-0.125, 359.875), 1440) + @test grid_interfaces == (-0.125, 5.125) + @test rN == 21 + + grid_interfaces, rN_off = restrict((40.0, 45.0), (-90.0, 90.0), 720) + @test grid_interfaces == (40.0, 45.0) + @test rN_off == 20 + + grid_interfaces, rN = restrict((0.0, 360.0), (0.0, 360.0), 1440) + @test grid_interfaces == (0.0, 360.0) + @test rN == 1440 + + # Sub-cell bbox still spans ≥ 1 native cell. + grid_interfaces, rN_tiny = restrict((0.1, 0.15), (-0.125, 359.875), 1440) + @test grid_interfaces == (-0.125, 0.375) + @test rN_tiny == 2 + + # Past-bounds bbox is clamped to native faces. + grid_interfaces, rN = restrict((-1.0, 1.5), (-0.125, 359.875), 1440) + @test grid_interfaces[1] == -0.125 + @test grid_interfaces[2] == 1.625 + @test rN == 7 + + # Stretched (Vector) interfaces. + interfaces = collect(0.0:1.0:10.0) + grid_interfaces, rN = restrict((2.0, 5.0), interfaces, 10) + @test grid_interfaces == [2.0, 3.0, 4.0, 5.0] + @test rN == 3 + grid_interfaces, rN = restrict((2.3, 4.7), interfaces, 10) + @test grid_interfaces == [2.0, 3.0, 4.0, 5.0] + @test rN == 3 + grid_interfaces, rN = restrict((-1.0, 1.5), interfaces, 10) + @test first(grid_interfaces) == 0.0 + @test rN >= 1 - # Pass-through for `nothing` (the no-restriction case). @test restrict(nothing, (0.0, 360.0), 1440) == ((0.0, 360.0), 1440) + @test restrict(nothing, interfaces, 10) == (interfaces, 10) end @testset "restrict_location dispatch" begin diff --git a/test/test_dataset_region.jl b/test/test_dataset_region.jl new file mode 100644 index 000000000..92e151921 --- /dev/null +++ b/test/test_dataset_region.jl @@ -0,0 +1,63 @@ +include("runtests_setup.jl") + +using NumericalEarth.DataWrangling: BoundingBox, Column, metadata_path +using Oceananigans: λnodes, φnodes, Center, interior +using Oceananigans.Fields: Field +using NCDatasets +using Dates + +@testset "Cross-dataset region support (snapshot path)" begin + arch = CPU() + + @testset "ECCO4 BoundingBox loads the right window" begin + # ECCO4Monthly stores longitude on [-180, 180]; pick a bbox far from + # the file's SW corner (which is the Antarctic ocean, all zeros) so + # that positional vs. coordinate indexing differ. + bbox = BoundingBox(longitude=(-60, 60), latitude=(-30, 30)) + md = Metadatum(:temperature; dataset=ECCO4Monthly(), + date=DateTime(1993, 1, 1), region=bbox) + f = Field(md, arch; inpainting=nothing, cache_inpainted_data=false) + + # Grid coordinates must fall inside the requested bbox (with a 1° + # tolerance for ECCO4's 0.5° spacing). + λg = λnodes(f.grid, Center()) + φg = φnodes(f.grid, Center()) + @test minimum(λg) ≥ -60 - 1.0 + @test maximum(λg) ≤ 60 + 1.0 + @test minimum(φg) ≥ -30 - 1.0 + @test maximum(φg) ≤ 30 + 1.0 + @test any(!iszero, interior(f)) + + # Correctness: hand-extract the reference value at (0°, 0°, surface) + # directly from the NetCDF file and compare against the bbox-restricted + # field at the same physical point. This catches the silent SW-corner + # positional-indexing bug in `set_metadata_field!`. + path = metadata_path(md) + ds = Dataset(path) + # ECCO4 surface is k=1 in the raw file (Z=-5 m). After + # `retrieve_data` reverses dims=3, surface becomes k=end in + # `interior(f)`. + T_full = ds["THETA"][:, :, 1, 1] + λfile = ds["longitude"][:] + φfile = ds["latitude"][:] + close(ds) + i★ = argmin(abs.(λfile .- 0)) + j★ = argmin(abs.(φfile .- 0)) + T_ref = T_full[i★, j★] + # Same physical point in the bbox-restricted field. + i_grid = argmin(abs.(λg .- 0)) + j_grid = argmin(abs.(φg .- 0)) + T_field = interior(f)[i_grid, j_grid, end] + @test T_field ≈ T_ref rtol=1e-3 + end + + @testset "ECCO4 Column extracts a single point" begin + col = Column(150.0, 0.0) + md = Metadatum(:temperature; dataset=ECCO4Monthly(), + date=DateTime(1993, 1, 1), region=col) + f = Field(md, arch; inpainting=nothing, cache_inpainted_data=false) + @test size(f.grid, 1) == 1 + @test size(f.grid, 2) == 1 + @test any(!iszero, interior(f)) + end +end diff --git a/test/test_jra55.jl b/test/test_jra55.jl index 3c9e75bc4..33b1fe0a5 100644 --- a/test/test_jra55.jl +++ b/test/test_jra55.jl @@ -15,7 +15,7 @@ using NumericalEarth.DataWrangling: compute_native_date_range dates = NumericalEarth.DataWrangling.all_dates(JRA55.RepeatYearJRA55(), test_name) end_date = dates[3] - JRA55_fts = JRA55FieldTimeSeries(test_name, arch; end_date) + JRA55_fts = FieldTimeSeries(Metadata(test_name; dataset=JRA55.RepeatYearJRA55(), end_date), arch) test_filename = joinpath(download_JRA55_cache, "RYF.rsds.1990_1991.nc") @test JRA55_fts isa FieldTimeSeries @@ -41,8 +41,8 @@ using NumericalEarth.DataWrangling: compute_native_date_range @info "Testing Cyclical time_indices for JRA55 data on $A..." Nb = 4 - backend = JRA55NetCDFBackend(Nb) - netcdf_JRA55_fts = JRA55FieldTimeSeries(test_name, arch; backend) + netcdf_JRA55_fts = FieldTimeSeries(Metadata(test_name; dataset=JRA55.RepeatYearJRA55()), arch; + time_indices_in_memory=Nb) Nt = length(netcdf_JRA55_fts.times) @test Oceananigans.OutputReaders.time_indices(netcdf_JRA55_fts) == (1, 2, 3, 4) @@ -58,12 +58,33 @@ using NumericalEarth.DataWrangling: compute_native_date_range @test f₁ == f₁′ end + @info "Testing Field(::JRA55Metadatum) on $A..." + # Locks in: position-based RepeatYear file_index lookup, halo + # periodic wrap, and agreement between the per-Metadatum Field + # path and the chunked-file FTS path at the same time index. + ds_var = :downwelling_shortwave_radiation + all_jra55_dates = NumericalEarth.DataWrangling.all_dates(JRA55.RepeatYearJRA55(), ds_var) + + md_first = Metadatum(ds_var; dataset=JRA55.RepeatYearJRA55()) + f_first = Field(md_first, arch) + @test f_first isa Field + @test size(f_first) == (640, 320, 1) + CUDA.@allowscalar @test f_first[1, 1, 1] == 430.98105f0 + CUDA.@allowscalar @test view(f_first.data, 1, :, 1) == view(f_first.data, 641, :, 1) + + md_mid = Metadatum(ds_var; dataset=JRA55.RepeatYearJRA55(), date=all_jra55_dates[100]) + f_mid = Field(md_mid, arch) + # Same time index loaded via the chunked-file FTS path → must agree + fts100 = FieldTimeSeries(Metadata(ds_var; dataset=JRA55.RepeatYearJRA55(), end_date=all_jra55_dates[100]), arch; time_indices_in_memory=100) + CUDA.@allowscalar @test f_mid[1, 1, 1] == fts100[1, 1, 1, 100] + CUDA.@allowscalar @test f_mid[640, 1, 1] == fts100[640, 1, 1, 100] + @info "Testing interpolate_field_time_series! on $A..." name = :downwelling_shortwave_radiation dates = NumericalEarth.DataWrangling.all_dates(JRA55.RepeatYearJRA55(), name) end_date = dates[3] - JRA55_fts = JRA55FieldTimeSeries(name, arch; end_date) + JRA55_fts = FieldTimeSeries(Metadata(name; dataset=JRA55.RepeatYearJRA55(), end_date), arch; time_indices_in_memory=3) # Make target grid and field resolution = 1 # degree, eg 1/4 @@ -116,12 +137,11 @@ using NumericalEarth.DataWrangling: compute_native_date_range ##### JRA55 prescribed atmosphere ##### - backend = JRA55NetCDFBackend(2) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=2) @test atmosphere isa PrescribedAtmosphere # Test JRA55PrescribedLand loads river and iceberg data with correct frequency - land = JRA55PrescribedLand(arch; backend) + land = JRA55PrescribedLand(arch; time_indices_in_memory=2) @test land isa NumericalEarth.Lands.PrescribedLand @test haskey(land.freshwater_flux, :rivers) @test haskey(land.freshwater_flux, :icebergs) @@ -140,8 +160,6 @@ using NumericalEarth.DataWrangling: compute_native_date_range start_date = DateTime("1959-01-01T00:00:00") - 15 * Day(1) # sometime in 1958 end_date = DateTime("1959-01-01T00:00:00") + 85 * Day(1) # sometime in 1959 - backend = JRA55NetCDFBackend(10) - # Use a temporary directory so different architectures don't clash mktempdir("./") do dir # Compute expected file paths so we can fall back to artifacts if needed @@ -151,7 +169,7 @@ using NumericalEarth.DataWrangling: compute_native_date_range filepaths = unique(metadata_path(metadata)) Ta = download_dataset_with_fallback(filepaths; dataset_name="MultiYearJRA55 :temperature") do - JRA55FieldTimeSeries(:temperature, arch; dataset, start_date, end_date, backend, dir) + FieldTimeSeries(metadata, arch; time_indices_in_memory=10) end @test Second(end_date - start_date).value ≈ Ta.times[end] - Ta.times[1] @@ -160,5 +178,36 @@ using NumericalEarth.DataWrangling: compute_native_date_range @test Ta[t] isa Field end end + + @info "Testing MultiYearJRA55 single-window crossing year boundary on $A..." + + # Force a single in-memory window to straddle the 1958 → 1959 file + # boundary. Before the per-file `ftsn_loc` fix, the second file's + # iteration in `set!` would clobber the outer `ftsn` and write to the + # wrong slots; this regression test would then leave some in-memory + # slots untouched (zero-valued). + start_date_span = DateTime("1958-12-27T00:00:00") + end_date_span = DateTime("1959-01-05T00:00:00") + + mktempdir("./") do dir + native_dates = NumericalEarth.DataWrangling.all_dates(dataset, :temperature) + dates = compute_native_date_range(native_dates, start_date_span, end_date_span) + metadata = Metadata(:temperature; dataset, dates, dir) + filepaths = unique(metadata_path(metadata)) + + Ta_span = download_dataset_with_fallback(filepaths; + dataset_name="MultiYearJRA55 :temperature year-boundary window") do + # backend window of 80 holds the whole range in a single window + FieldTimeSeries(metadata, arch; time_indices_in_memory=80) + end + + # Every slot in the single in-memory window must carry valid + # (non-zero) atmospheric temperature data. + CUDA.@allowscalar begin + for t in eachindex(Ta_span.times) + @test maximum(abs, interior(Ta_span[t])) > 0 + end + end + end end end diff --git a/test/test_jra55_ecco_en4_etopo_downloading.jl b/test/test_jra55_ecco_en4_etopo_downloading.jl index 7fcd841ff..ff3f1f723 100644 --- a/test/test_jra55_ecco_en4_etopo_downloading.jl +++ b/test/test_jra55_ecco_en4_etopo_downloading.jl @@ -8,11 +8,11 @@ include("download_utils.jl") datum = Metadatum(name; dataset=JRA55.RepeatYearJRA55()) filepath = metadata_path(datum) - fts = download_dataset_with_fallback(filepath; dataset_name="JRA55 $name") do - NumericalEarth.JRA55.JRA55FieldTimeSeries(name; backend=NumericalEarth.JRA55.JRA55NetCDFBackend(2)) + download_dataset_with_fallback(filepath; dataset_name="JRA55 $name") do + FieldTimeSeries(Metadata(name; dataset=NumericalEarth.JRA55.RepeatYearJRA55()); time_indices_in_memory=2) end - @test isfile(fts.path) - rm(fts.path; force=true) + @test isfile(filepath) + rm(filepath; force=true) end end diff --git a/test/test_jra55_region.jl b/test/test_jra55_region.jl new file mode 100644 index 000000000..506803204 --- /dev/null +++ b/test/test_jra55_region.jl @@ -0,0 +1,58 @@ +include("runtests_setup.jl") + +using NumericalEarth.DataWrangling: BoundingBox, Column, Linear +using Oceananigans.Fields: interpolate as oc_interpolate +using Oceananigans.Grids: topology, Bounded + +@testset "JRA55 region support" begin + arch = CPU() + + @testset "BoundingBox slices the right window" begin + bbox = BoundingBox(longitude=(120, 240), latitude=(-30, 30)) + atm = JRA55PrescribedAtmosphere(arch; + time_indices_in_memory=2, + region=bbox) + Ta = atm.tracers.T + # Coordinates of the field grid should fall inside the requested bbox. + λnodes_T = λnodes(Ta.grid, Center()) + φnodes_T = φnodes(Ta.grid, Center()) + @test minimum(λnodes_T) ≥ 120 - 1.5 # 1.5° JRA55 spacing tolerance + @test maximum(λnodes_T) ≤ 240 + 1.5 + @test minimum(φnodes_T) ≥ -30 - 1.5 + @test maximum(φnodes_T) ≤ 30 + 1.5 + @test any(!iszero, interior(Ta)) + @test !any(isnan, interior(Ta)) + # Sub-360° span must be Bounded in x so halos do not wrap. + @test topology(Ta.grid)[1] == Bounded + end + + @testset "Column extracts a single point" begin + col = Column(150.0, 0.0) # equator, central Pacific + atm = JRA55PrescribedAtmosphere(arch; + time_indices_in_memory=2, + region=col) + Ta = atm.tracers.T + @test size(Ta.grid, 1) == 1 + @test size(Ta.grid, 2) == 1 + @test any(!iszero, interior(Ta)) + @test !any(isnan, interior(Ta)) + end + + @testset "Column matches bbox bilinear at the column point" begin + # The Column dispatch should produce the same value as bilinearly + # interpolating a bbox-extracted FTS at the column's (lon, lat). + col_atm = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=2, + region=Column(150.0, 0.0)) + bbox_atm = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=2, + region=BoundingBox(longitude=(148, 152), + latitude=(-2, 2))) + + Ta_col = col_atm.tracers.T + Ta_bbox = bbox_atm.tracers.T + + T_col_t1 = interior(Ta_col)[1, 1, 1, 1] + loc = (Center(), Center(), Center()) + T_bbox_t1 = oc_interpolate((150.0, 0.0, 0.0), Ta_bbox[1], loc, Ta_bbox.grid) + @test T_col_t1 ≈ T_bbox_t1 rtol = 1e-3 + end +end diff --git a/test/test_mangling.jl b/test/test_mangling.jl new file mode 100644 index 000000000..bd2886ea4 --- /dev/null +++ b/test/test_mangling.jl @@ -0,0 +1,48 @@ +include("runtests_setup.jl") + +using NumericalEarth.DataWrangling: mangle, mangling_for, ShiftSouth, AverageNorthSouth + +using Oceananigans.Fields: location + +@testset "mangle dispatch" begin + data = reshape(Float32[1 2 3; 4 5 6; 7 8 9], 3, 3, 1) + + # Identity: reads (i, j, k) directly. + @test mangle(2, 2, 1, data, nothing) == 5 + + # ShiftSouth: reads (i, j-1, k); j=1 clamps to j=1. + @test mangle(2, 2, 1, data, ShiftSouth()) == 4 + @test mangle(2, 1, 1, data, ShiftSouth()) == 4 + + # AverageNorthSouth: averages (i, j, k) and (i, j+1, k). + @test mangle(2, 1, 1, data, AverageNorthSouth()) ≈ 4.5f0 + @test mangle(2, 2, 1, data, AverageNorthSouth()) ≈ 5.5f0 +end + +@testset "mangling_for size dispatch" begin + md = Metadatum(:v_velocity; dataset=ECCO4Monthly(), date=start_date) + _, Ny, _, _ = size(md) + + @test mangling_for(md, Ny) === nothing + @test mangling_for(md, Ny - 1) isa ShiftSouth + @test mangling_for(md, Ny + 1) isa AverageNorthSouth + @test mangling_for(md, Ny - 2) === nothing + @test mangling_for(md, Ny + 2) === nothing +end + +@testset "ECCO v_velocity Field uses ShiftSouth mangling end-to-end" begin + for arch in test_architectures + md = Metadatum(:v_velocity; dataset=ECCO4Monthly(), date=start_date) + field = Field(md, arch) + + # v lives on the latitude Face for ECCO; the file ships Ny-1 lat + # entries, so mangling_for must select ShiftSouth — anything else + # would either go out of bounds or read off-by-one everywhere. + @test location(field) == (Center, Face, Center) + + @allowscalar begin + @test any(!=(0), interior(field)) + @test any(isfinite, interior(field)) + end + end +end diff --git a/test/test_metadata.jl b/test/test_metadata.jl index e55fd2fdf..389794cd9 100644 --- a/test/test_metadata.jl +++ b/test/test_metadata.jl @@ -3,6 +3,7 @@ include("runtests_setup.jl") using NumericalEarth.DataWrangling: Column, Linear, Nearest, BoundingBox, dataset_location, restrict_location, native_grid +using NumericalEarth.DataWrangling: restrict using Oceananigans: RectilinearGrid, LatitudeLongitudeGrid, location using Oceananigans.Grids: topology, Flat, Bounded, Periodic @@ -111,6 +112,19 @@ end Nx, Ny, Nz = size(grid) @test Nx < Nx_full @test Ny < Ny_full + + # Sub-360° bbox must be Bounded in x (not Periodic) so halos don't wrap. + @test topology(grid)[1] == Bounded + + # 360°-spanning bbox keeps Periodic in x. + bbox_full = BoundingBox(longitude=(-180, 180), latitude=(-30, 30)) + md_full = Metadatum(:temperature; dataset=ECCO4Monthly(), region=bbox_full) + @test topology(native_grid(md_full))[1] == Periodic + + # Latitude-only restriction: longitude is unrestricted, x stays Periodic. + bbox_lat = BoundingBox(latitude=(-30, 30)) + md_lat = Metadatum(:temperature; dataset=ECCO4Monthly(), region=bbox_lat) + @test topology(native_grid(md_lat))[1] == Periodic end @testset "Metadata region keyword" begin @@ -141,3 +155,39 @@ end @test last(md).region === col @test md[1].region === col end + +@testset "restrict() snaps to native interfaces" begin + # Uniform interfaces: snapping coincides with the user's bbox if it + # already lies on cell boundaries. + interfaces = collect(0.0:1.0:10.0) + sliced, rN = restrict((2.0, 6.0), interfaces, 10) + @test sliced == [2.0, 3.0, 4.0, 5.0, 6.0] + @test rN == 4 + + # Uniform interfaces, off-grid bbox: snap outward to the surrounding + # native cells so the result is a superset of the request. + sliced, rN = restrict((2.5, 6.5), interfaces, 10) + @test sliced[1] ≤ 2.5 + @test sliced[end] ≥ 6.5 + @test rN == length(sliced) - 1 + + # Stretched interfaces (cells get wider): snapping must return the + # actual native interfaces, not a 2-tuple of the user's bbox. + stretched = [0.0, 0.5, 1.5, 3.0, 5.5, 9.5, 15.0] + sliced, rN = restrict((1.0, 6.0), stretched, length(stretched) - 1) + @test sliced == [0.5, 1.5, 3.0, 5.5, 9.5] + @test rN == 4 + + # Out-of-range bbox is clamped, not crashed. + sliced, rN = restrict((-100.0, 100.0), stretched, length(stretched) - 1) + @test sliced == stretched + @test rN == length(stretched) - 1 + + # 2-tuple endpoints: uniform native grids return the bbox endpoints + # verbatim (no snap) with a proportional cell count. Stays correct across + # longitude conventions for pre-subsetted files (e.g. GLORYS via Copernicus); + # the centre alignment is handled at read time by `region_info`. + sliced, rN = restrict((120, 240), (0, 360), 360) + @test sliced == (120, 240) + @test rN == 120 +end diff --git a/test/test_ocean_only_model.jl b/test/test_ocean_only_model.jl index 6254f9e63..35f2f7211 100644 --- a/test/test_ocean_only_model.jl +++ b/test/test_ocean_only_model.jl @@ -17,8 +17,7 @@ using Oceananigans.OrthogonalSphericalShellGrids data = Int[] pushdata(sim) = push!(data, iteration(sim)) add_callback!(ocean, pushdata) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) coupled_model = OceanOnlyModel(ocean; atmosphere, radiation) Δt = 60 @@ -48,8 +47,7 @@ using Oceananigans.OrthogonalSphericalShellGrids free_surface = SplitExplicitFreeSurface(grid; substeps=20) ocean = ocean_simulation(grid; free_surface) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) # Fluxes are computed when the model is constructed, so we just test that this works. @@ -64,7 +62,8 @@ using Oceananigans.OrthogonalSphericalShellGrids ##### @info "Testing OceanOnlyModel with JRA55PrescribedLand on $A..." - land = JRA55PrescribedLand(arch; backend) + land_dates = all_dates(RepeatYearJRA55(), :river_freshwater_flux) + land = JRA55PrescribedLand(arch; end_date=land_dates[2]) @test begin ocean_with_land = ocean_simulation(grid; free_surface) diff --git a/test/test_ocean_sea_ice_model.jl b/test/test_ocean_sea_ice_model.jl index 4188435ec..268ba339f 100644 --- a/test/test_ocean_sea_ice_model.jl +++ b/test/test_ocean_sea_ice_model.jl @@ -57,8 +57,7 @@ using ClimaSeaIce.Rheologies Tm = KernelFunctionOperation{Center, Center, Center}(kernel_melting_temperature, grid, liquidus, S) @test all(T .≥ Tm) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) # Fluxes are computed when the model is constructed, so we just test that this works. @@ -71,7 +70,8 @@ using ClimaSeaIce.Rheologies # Test with land component @info "Testing OceanSeaIceModel with land on $A..." - land = JRA55PrescribedLand(arch; backend) + land_dates = all_dates(RepeatYearJRA55(), :river_freshwater_flux) + land = JRA55PrescribedLand(arch; end_date=land_dates[2]) @test begin ocean_with_land = ocean_simulation(grid; free_surface) diff --git a/test/test_sea_ice_ocean_heat_fluxes.jl b/test/test_sea_ice_ocean_heat_fluxes.jl index 7506d3d78..df7ce565a 100644 --- a/test/test_sea_ice_ocean_heat_fluxes.jl +++ b/test/test_sea_ice_ocean_heat_fluxes.jl @@ -199,8 +199,7 @@ end ocean = ocean_simulation(grid, momentum_advection=nothing, closure=nothing, tracer_advection=nothing) sea_ice = sea_ice_simulation(grid, ocean) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) for sea_ice_ocean_heat_flux in [IceBathHeatFlux(), ThreeEquationHeatFlux()] @@ -257,8 +256,7 @@ end ocean = ocean_simulation(grid, momentum_advection=nothing, closure=nothing, tracer_advection=nothing) sea_ice = sea_ice_simulation(grid, ocean) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) for sea_ice_ocean_heat_flux in [IceBathHeatFlux(), ThreeEquationHeatFlux()] @@ -401,8 +399,7 @@ end ocean = ocean_simulation(grid, momentum_advection=nothing, closure=nothing, tracer_advection=nothing) sea_ice = sea_ice_simulation(grid, ocean) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) for sea_ice_ocean_heat_flux in [IceBathHeatFlux(), ThreeEquationHeatFlux()] @@ -452,8 +449,7 @@ end ocean = ocean_simulation(grid, momentum_advection=nothing, closure=nothing, tracer_advection=nothing) sea_ice = sea_ice_simulation(grid, ocean) - backend = JRA55NetCDFBackend(4) - atmosphere = JRA55PrescribedAtmosphere(arch; backend) + atmosphere = JRA55PrescribedAtmosphere(arch; time_indices_in_memory=4) radiation = Radiation(arch) # Test with ThreeEquationHeatFlux (default) diff --git a/test/test_surface_fluxes.jl b/test/test_surface_fluxes.jl index 3ffdd7b60..632fbbdd6 100644 --- a/test/test_surface_fluxes.jl +++ b/test/test_surface_fluxes.jl @@ -32,7 +32,7 @@ end @testset "Test surface fluxes" begin for arch in test_architectures - grid = LatitudeLongitudeGrid(arch; + grid = LatitudeLongitudeGrid(arch, Float32; size = 1, latitude = 10, longitude = 10, @@ -46,14 +46,14 @@ end bottom_drag_coefficient = 0) dates = all_dates(RepeatYearJRA55(), :temperature) - atmosphere = JRA55PrescribedAtmosphere(arch, Float64; end_date=dates[2], backend = InMemory()) + atmosphere = JRA55PrescribedAtmosphere(arch; end_date=dates[2]) CUDA.@allowscalar begin h = atmosphere.surface_layer_height pᵃᵗ = atmosphere.pressure[1][1, 1, 1] Tᵃᵗ = 15 + celsius_to_kelvin - qᵃᵗ = 0.003 + qᵃᵗ = Float32(0.003) uᵃᵗ = atmosphere.velocities.u[1][1, 1, 1] vᵃᵗ = atmosphere.velocities.v[1][1, 1, 1] @@ -196,7 +196,7 @@ end set!(ocean_with_land.model, T = 15, S = 30) land_dates = all_dates(RepeatYearJRA55(), :river_freshwater_flux) - land = JRA55PrescribedLand(arch; end_date=land_dates[2], backend = InMemory()) + land = JRA55PrescribedLand(arch; end_date=land_dates[2]) model_with_land = OceanOnlyModel(ocean_with_land; atmosphere, land) # Verify land exchanger is wired up @@ -232,7 +232,7 @@ end bottom_drag_coefficient = 0.0) dates = all_dates(RepeatYearJRA55(), :temperature) - atmosphere = JRA55PrescribedAtmosphere(arch; end_date=dates[2], backend = InMemory()) + atmosphere = JRA55PrescribedAtmosphere(arch; end_date=dates[2]) fill!(ocean.model.tracers.T, -2.0)