Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "XESMF"
uuid = "2e0b0046-e7a1-486f-88de-807ee8ffabe5"
authors = ["NumericalEarth and contributors"]
version = "0.1.3"
version = "0.1.4"

[deps]
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
Expand Down
21 changes: 18 additions & 3 deletions src/XESMF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct Regridder{S, M, V1, V2}
dst_temp :: V2
end

Base.summary(r::Regridder{S, M, V1, V2}) where {S, M, V1, V2} = "$(r.method) Regridder"
Base.summary(r::Regridder{S, M, V1, V2}) where {S, M, V1, V2} = "$(uppercasefirst(r.method)) Regridder"

function Base.show(io::IO, r::Regridder)
print(io, summary(r), '\n')
Expand Down Expand Up @@ -41,13 +41,28 @@ function sparse_regridder_weights(FT, regridder)
return weights
end

# The easy case, for dense vectors, the regridding operation defaults to a matrix multiplication
(regridder::Regridder)(dst::DenseVector, src::DenseVector) = LinearAlgebra.mul!(dst, regridder.weights, src)

# Generic regridding function that does not check the dimensions of the `src` and
# `dst` arrays
function regrid!(dst::AbstractVector, regridder::Regridder, src::AbstractVector)
# `dst` arrays. For general vectors that might be discontinuous in memory, we need
# to broadcast the value to a `DenseArray` before performing the sparse matrix multiply
function (regridder::Regridder)(dst::AbstractVector, src::AbstractVector)
regridder.src_temp .= src
LinearAlgebra.mul!(regridder.dst_temp, regridder.weights, regridder.src_temp)
dst .= regridder.dst_temp
return dst
end

# Mixed cases
function (regridder::Regridder)(dst::DenseVector, src::AbstractVector)
regridder.src_temp .= src
return LinearAlgebra.mul!(dst, regridder.weights, regridder.src_temp)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does LinearAlgebra.mul! return its first arg?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so:

julia> A=[1.0 2.0; 3.0 4.0]; B=[1.0 1.0; 1.0 1.0]; Y = similar(B); mul!(Y, A, B)
2×2 Matrix{Float64}:
 3.0  3.0
 7.0  7.0

end

function (regridder::Regridder)(dst::AbstractVector, src::DenseVector)
LinearAlgebra.mul!(regridder.dst_temp, regridder.weights, src)
dst .= regridder.dst_temp
return dst
end

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ include("setup_runtests.jl")
@testset "Unit Tests" begin
include("test_unit.jl")
end

@testset "Oceananigans Integration Tests" begin
include("test_oceananigans.jl")
end
Expand Down
1 change: 0 additions & 1 deletion test/setup_runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Test
using Oceananigans
using PythonCall
using XESMF
61 changes: 41 additions & 20 deletions test/test_oceananigans.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
include("setup_runtests.jl")

using Oceananigans
using Oceananigans.Fields: AbstractField
using SparseArrays

x_node_array(x::AbstractVector, Nx, Ny) = view(x, 1:Nx) |> Array
y_node_array(x::AbstractVector, Nx, Ny) = view(x, 1:Ny) |> Array
x_node_array(x::AbstractMatrix, Nx, Ny) = view(x, 1:Nx, 1:Ny) |> Array
function x_node_array(x::AbstractVector, Nx, Ny)
return Array(repeat(view(x, 1:Nx), 1, Ny))'
end
function y_node_array(x::AbstractVector, Nx, Ny)
return Array(repeat(view(x, 1:Ny)', Nx, 1))'
end
x_node_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx, 1:Ny))'

x_vertex_array(x::AbstractVector, Nx, Ny) = view(x, 1:Nx+1) |> Array
y_vertex_array(x::AbstractVector, Nx, Ny) = view(x, 1:Ny+1) |> Array
x_vertex_array(x::AbstractMatrix, Nx, Ny) = view(x, 1:Nx+1, 1:Ny+1) |> Array
function x_vertex_array(x::AbstractVector, Nx, Ny)
return Array(repeat(view(x, 1:Nx+1), 1, Ny+1))'
end
function y_vertex_array(x::AbstractVector, Nx, Ny)
return Array(repeat(view(x, 1:Ny+1)', Nx+1, 1))'
end
x_vertex_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx+1, 1:Ny+1))'

y_node_array(x::AbstractMatrix, Nx, Ny) = x_node_array(x, Nx, Ny)
y_vertex_array(x::AbstractMatrix, Nx, Ny) = x_vertex_array(x, Nx, Ny)

function regridding_weights(dst_field, src_field)
function extract_xesmf_coordinates_structure(dst_field::AbstractField, src_field::AbstractField)
ℓx, ℓy, ℓz = Oceananigans.Fields.instantiated_location(src_field)

# We only support regridding between centered fields.
@assert ℓx isa Center
@assert ℓy isa Center
@assert (ℓx, ℓy, ℓz) == Oceananigans.Fields.instantiated_location(dst_field)

dst_grid = dst_field.grid
src_grid = src_field.grid

Expand All @@ -35,7 +41,7 @@ function regridding_weights(dst_field, src_field)
λvˢ = λnodes(src_grid, Face(), Face(), ℓz, with_halos=true)
φvˢ = φnodes(src_grid, Face(), Face(), ℓz, with_halos=true)

# Build data structures expected by XESMF.
# Build data structures expected by xESMF
Nˢx, Nˢy, Nˢz = size(src_field)
Nᵈx, Nᵈy, Nᵈz = size(dst_field)

Expand All @@ -59,7 +65,7 @@ function regridding_weights(dst_field, src_field)
"lat_b" => φvˢ,
"lon_b" => λvˢ)

return src_coordinates, dst_coordinates
return dst_coordinates, src_coordinates
end

@testset "Oceananigans Integration Tests" begin
Expand All @@ -72,7 +78,7 @@ end
cll = CenterField(ll)

# Test that we can create the coordinate structures
src_coordinates, dst_coordinates = regridding_weights(ctg, cll)
dst_coordinates, src_coordinates = extract_xesmf_coordinates_structure(cll, ctg)

# Verify coordinate structures are valid
@test haskey(src_coordinates, "lat")
Expand All @@ -89,12 +95,27 @@ end
@test size(tg) == (360, 170, 1)
@test size(ll) == (360, 180, 1)

xesmf = XESMF.xesmf
periodic = Oceananigans.Grids.topology(ctg.grid, 1) === Periodic ? PythonCall.pybuiltins.True : pybuiltins.False
periodic = Oceananigans.Grids.topology(ctg.grid, 1) === Periodic ? true : false
method = "conservative"
regridder = xesmf.Regridder(src_coordinates, dst_coordinates, method; periodic)
weights = XESMF.sparse_regridder_weights(regridder)
regridder = XESMF.Regridder(src_coordinates, dst_coordinates; method, periodic)

@test regridder.weights isa SparseMatrixCSC

# test that the regridder works with dense and strided arrays
dense_tg = zeros(prod(size(tg)))
dense_ll = zeros(prod(size(ll)))

strided_tg = vec(view(zeros(size(tg, 1)+5, size(tg, 2)+5), 1:size(tg, 1), 1:size(tg, 2)))
strided_ll = vec(view(zeros(size(ll, 1)+5, size(ll, 2)+5), 1:size(ll, 1), 1:size(ll, 2)))

rand_tg = rand(prod(size(tg)))
rand_ll = rand(prod(size(ll)))

dense_tg .= rand_tg
strided_tg .= rand_tg
regridder(dense_ll, dense_tg)
regridder(strided_ll, strided_tg)

@test weights isa SparseMatrixCSC
@test all(dense_ll .== strided_ll)
end
end