diff --git a/Project.toml b/Project.toml index f77b0bc..6ed35ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "XESMF" uuid = "2e0b0046-e7a1-486f-88de-807ee8ffabe5" authors = ["NumericalEarth and contributors"] -version = "0.1.3" +version = "0.1.4" [deps] CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" diff --git a/src/XESMF.jl b/src/XESMF.jl index 84f90d6..b25694c 100644 --- a/src/XESMF.jl +++ b/src/XESMF.jl @@ -12,7 +12,7 @@ struct Regridder{S, M, V1, V2} dst_temp :: V2 end -Base.summary(r::Regridder{S, M, V1, V2}) where {S, M, V1, V2} = "$(r.method) Regridder" +Base.summary(r::Regridder{S, M, V1, V2}) where {S, M, V1, V2} = "$(uppercasefirst(r.method)) Regridder" function Base.show(io::IO, r::Regridder) print(io, summary(r), '\n') @@ -41,13 +41,28 @@ function sparse_regridder_weights(FT, regridder) return weights end +# The easy case, for dense vectors, the regridding operation defaults to a matrix multiplication +(regridder::Regridder)(dst::DenseVector, src::DenseVector) = LinearAlgebra.mul!(dst, regridder.weights, src) + # Generic regridding function that does not check the dimensions of the `src` and -# `dst` arrays -function regrid!(dst::AbstractVector, regridder::Regridder, src::AbstractVector) +# `dst` arrays. For general vectors that might be discontinuous in memory, we need +# to broadcast the value to a `DenseArray` before performing the sparse matrix multiply +function (regridder::Regridder)(dst::AbstractVector, src::AbstractVector) regridder.src_temp .= src LinearAlgebra.mul!(regridder.dst_temp, regridder.weights, regridder.src_temp) dst .= regridder.dst_temp + return dst +end +# Mixed cases +function (regridder::Regridder)(dst::DenseVector, src::AbstractVector) + regridder.src_temp .= src + return LinearAlgebra.mul!(dst, regridder.weights, regridder.src_temp) +end + +function (regridder::Regridder)(dst::AbstractVector, src::DenseVector) + LinearAlgebra.mul!(regridder.dst_temp, regridder.weights, src) + dst .= regridder.dst_temp return dst end diff --git a/test/runtests.jl b/test/runtests.jl index d6fb54c..60c158f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ include("setup_runtests.jl") @testset "Unit Tests" begin include("test_unit.jl") end - + @testset "Oceananigans Integration Tests" begin include("test_oceananigans.jl") end diff --git a/test/setup_runtests.jl b/test/setup_runtests.jl index 9a42277..562e8de 100644 --- a/test/setup_runtests.jl +++ b/test/setup_runtests.jl @@ -1,4 +1,3 @@ using Test -using Oceananigans using PythonCall using XESMF \ No newline at end of file diff --git a/test/test_oceananigans.jl b/test/test_oceananigans.jl index 095def5..9f6daaf 100644 --- a/test/test_oceananigans.jl +++ b/test/test_oceananigans.jl @@ -1,25 +1,31 @@ include("setup_runtests.jl") + +using Oceananigans +using Oceananigans.Fields: AbstractField using SparseArrays -x_node_array(x::AbstractVector, Nx, Ny) = view(x, 1:Nx) |> Array -y_node_array(x::AbstractVector, Nx, Ny) = view(x, 1:Ny) |> Array -x_node_array(x::AbstractMatrix, Nx, Ny) = view(x, 1:Nx, 1:Ny) |> Array +function x_node_array(x::AbstractVector, Nx, Ny) + return Array(repeat(view(x, 1:Nx), 1, Ny))' +end +function y_node_array(x::AbstractVector, Nx, Ny) + return Array(repeat(view(x, 1:Ny)', Nx, 1))' +end +x_node_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx, 1:Ny))' -x_vertex_array(x::AbstractVector, Nx, Ny) = view(x, 1:Nx+1) |> Array -y_vertex_array(x::AbstractVector, Nx, Ny) = view(x, 1:Ny+1) |> Array -x_vertex_array(x::AbstractMatrix, Nx, Ny) = view(x, 1:Nx+1, 1:Ny+1) |> Array +function x_vertex_array(x::AbstractVector, Nx, Ny) + return Array(repeat(view(x, 1:Nx+1), 1, Ny+1))' +end +function y_vertex_array(x::AbstractVector, Nx, Ny) + return Array(repeat(view(x, 1:Ny+1)', Nx+1, 1))' +end +x_vertex_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx+1, 1:Ny+1))' y_node_array(x::AbstractMatrix, Nx, Ny) = x_node_array(x, Nx, Ny) y_vertex_array(x::AbstractMatrix, Nx, Ny) = x_vertex_array(x, Nx, Ny) -function regridding_weights(dst_field, src_field) +function extract_xesmf_coordinates_structure(dst_field::AbstractField, src_field::AbstractField) ℓx, ℓy, ℓz = Oceananigans.Fields.instantiated_location(src_field) - # We only support regridding between centered fields. - @assert ℓx isa Center - @assert ℓy isa Center - @assert (ℓx, ℓy, ℓz) == Oceananigans.Fields.instantiated_location(dst_field) - dst_grid = dst_field.grid src_grid = src_field.grid @@ -35,7 +41,7 @@ function regridding_weights(dst_field, src_field) λvˢ = λnodes(src_grid, Face(), Face(), ℓz, with_halos=true) φvˢ = φnodes(src_grid, Face(), Face(), ℓz, with_halos=true) - # Build data structures expected by XESMF. + # Build data structures expected by xESMF Nˢx, Nˢy, Nˢz = size(src_field) Nᵈx, Nᵈy, Nᵈz = size(dst_field) @@ -59,7 +65,7 @@ function regridding_weights(dst_field, src_field) "lat_b" => φvˢ, "lon_b" => λvˢ) - return src_coordinates, dst_coordinates + return dst_coordinates, src_coordinates end @testset "Oceananigans Integration Tests" begin @@ -72,7 +78,7 @@ end cll = CenterField(ll) # Test that we can create the coordinate structures - src_coordinates, dst_coordinates = regridding_weights(ctg, cll) + dst_coordinates, src_coordinates = extract_xesmf_coordinates_structure(cll, ctg) # Verify coordinate structures are valid @test haskey(src_coordinates, "lat") @@ -89,12 +95,27 @@ end @test size(tg) == (360, 170, 1) @test size(ll) == (360, 180, 1) - xesmf = XESMF.xesmf - periodic = Oceananigans.Grids.topology(ctg.grid, 1) === Periodic ? PythonCall.pybuiltins.True : pybuiltins.False + periodic = Oceananigans.Grids.topology(ctg.grid, 1) === Periodic ? true : false method = "conservative" - regridder = xesmf.Regridder(src_coordinates, dst_coordinates, method; periodic) - weights = XESMF.sparse_regridder_weights(regridder) + regridder = XESMF.Regridder(src_coordinates, dst_coordinates; method, periodic) + + @test regridder.weights isa SparseMatrixCSC + + # test that the regridder works with dense and strided arrays + dense_tg = zeros(prod(size(tg))) + dense_ll = zeros(prod(size(ll))) + + strided_tg = vec(view(zeros(size(tg, 1)+5, size(tg, 2)+5), 1:size(tg, 1), 1:size(tg, 2))) + strided_ll = vec(view(zeros(size(ll, 1)+5, size(ll, 2)+5), 1:size(ll, 1), 1:size(ll, 2))) + + rand_tg = rand(prod(size(tg))) + rand_ll = rand(prod(size(ll))) + + dense_tg .= rand_tg + strided_tg .= rand_tg + regridder(dense_ll, dense_tg) + regridder(strided_ll, strided_tg) - @test weights isa SparseMatrixCSC + @test all(dense_ll .== strided_ll) end end