diff --git a/Project.toml b/Project.toml index 2b7d7b7f9..da4164895 100644 --- a/Project.toml +++ b/Project.toml @@ -11,8 +11,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" @@ -38,8 +38,8 @@ Adapt = "4.4" ChunkSplitters = "3" DiffResults = "1" Distributed = "1" +FastInterpolations = "0.4.1" ForwardDiff = "1" -Interpolations = "0.15, 0.16" KernelAbstractions = "0.9" LinearAlgebra = "1" Meshes = "0.55, 0.56" diff --git a/src/TestParticle.jl b/src/TestParticle.jl index f5c1bed22..446ab712b 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -1,9 +1,8 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize -using Interpolations: interpolate, interpolate!, extrapolate, scale, BSpline, Linear, - Quadratic, Cubic, - Line, OnCell, Periodic, Flat, Gridded +using FastInterpolations: linear_interp, quadratic_interp, cubic_interp, constant_interp, + Extrap, NoExtrap, PeriodicBC, ZeroCurvBC, FillExtrap using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, @@ -46,7 +45,7 @@ export EnsembleSerial, EnsembleThreads, EnsembleDistributed, remake include("types.jl") include("utility/utility.jl") -include("utility/interpolation.jl") +include("utility/fastinterpolation.jl") include("sampler.jl") include("prepare.jl") include("gc.jl") diff --git a/src/prepare.jl b/src/prepare.jl index 01ff73475..4dacd209f 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -148,7 +148,7 @@ function prepare( return _prepare(E, B, F, x, y, z; gridtype = StructuredGrid, order, bc, kw...) end -function prepare(x::AbstractVector, E, B, F = ZeroField(); order = 1, bc = 3, dir = 1, kw...) +function prepare(x::AbstractVector, E, B, F = ZeroField(); order = 1, bc = 1, dir = 1, kw...) @assert issorted(x) "Grid vector `x` must be sorted." return _prepare(E, B, F, x; gridtype = CartesianGrid, order, bc, dir, kw...) end diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl new file mode 100644 index 000000000..e7e64fd03 --- /dev/null +++ b/src/utility/fastinterpolation.jl @@ -0,0 +1,365 @@ +# Field interpolations. + +""" + AbstractFieldInterpolator + +Abstract type for all field interpolators. +""" +abstract type AbstractFieldInterpolator <: Function end + +""" + FieldInterpolator{T} + +A callable struct that wraps a 3D interpolation object. +""" +struct FieldInterpolator{T} <: AbstractFieldInterpolator + itp::T +end + +const FieldInterpolator3D = FieldInterpolator + +@inbounds function (fi::FieldInterpolator)(xu) + return fi.itp((xu[1], xu[2], xu[3])) +end + +function (fi::FieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp)) + +""" + FieldInterpolator2D{T} + +A callable struct that wraps a 2D interpolation object. +""" +struct FieldInterpolator2D{T} <: AbstractFieldInterpolator + itp::T +end + +@inbounds function (fi::FieldInterpolator2D)(xu) + # 2D interpolation usually involves x and y + return fi.itp((xu[1], xu[2])) +end + +function (fi::FieldInterpolator2D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp)) + +""" + FieldInterpolator1D{T} + +A callable struct that wraps a 1D interpolation object. +""" +struct FieldInterpolator1D{T} <: AbstractFieldInterpolator + itp::T + dir::Int +end + +@inbounds function (fi::FieldInterpolator1D)(xu) + return fi.itp((xu[fi.dir],)) +end + +function (fi::FieldInterpolator1D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), fi.dir) + +""" + SphericalFieldInterpolator{T} + +A callable struct for spherical grid interpolation (scalar or combined vector). +""" +struct SphericalFieldInterpolator{T} <: AbstractFieldInterpolator + itp::T +end + +function (fi::SphericalFieldInterpolator)(xu) + rθϕ = cart2sph(xu) + res = fi.itp(rθϕ) + if length(res) > 1 + # Convert vector result from spherical to cartesian basis + Br, Bθ, Bϕ = res + return sph_to_cart_vector(Br, Bθ, Bϕ, rθϕ[2], rθϕ[3]) + else + return res + end +end + +function (fi::SphericalFieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) + +function _get_extrap_mode(bc, T::Type) + if bc == 2 + return Extrap(:wrap) + elseif bc == 3 + return Extrap(:clamp) + else + if T <: SVector + return FillExtrap(SVector{3, eltype(T)}(NaN, NaN, NaN)) + else + return FillExtrap(T(NaN)) + end + end +end + +function _fastinterp(grids, A, order, bc) + extrap_mode = _get_extrap_mode(bc, eltype(A)) + if order == 1 + return linear_interp(grids, A; extrap = extrap_mode) + elseif order == 2 + return quadratic_interp(grids, A; extrap = extrap_mode) + elseif order == 3 + return cubic_interp(grids, A; extrap = extrap_mode) + else + return constant_interp(grids, A; extrap = extrap_mode) + end +end + +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) + +""" + build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) + build_interpolator(A, grids..., order::Int=1, bc::Int=1) + +Return a function for interpolating field array `A` on the given grids. + +# Arguments + + - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. + - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. + - `order::Int=1`: order of interpolation in [1,2,3]. + - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Clamp (flat extrapolation). + +# Notes +The input array `A` may be modified in-place for memory optimization. +""" +function build_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) +end + +function build_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 + ) where {T} + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + end + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + # Return field value at a given location. + return FieldInterpolator(itp) +end + +function build_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) +end + +function build_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 + ) where {T} + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + end + if order != 1 + throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) + end + + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp) +end + +function build_interpolator( + ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) +end + +function build_interpolator( + ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 + ) where {T} + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + end + # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) + # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! + is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) + + if is_uniform_r + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + else # Non-uniform R (SphericalNonUniformR behavior) + gridϕ, A = _ensure_full_phi(gridϕ, A) + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + end + + return SphericalFieldInterpolator(itp) +end + +function build_interpolator( + ::Type{<:CartesianGrid}, A, + gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 1 + ) + if eltype(A) <: SVector + @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." + As = A + else + @assert size(A, 1) == 3 && ndims(A) == 3 "Inconsistent 2D force field and grid!" + As = reinterpret(reshape, SVector{3, eltype(A)}, A) + end + + itp = _fastinterp((gridx, gridy), As, order, bc) + return FieldInterpolator2D(itp) +end + +function build_interpolator( + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + order::Int = 1, bc::Int = 1; dir = 1 + ) + if eltype(A) <: SVector + @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." + As = A + else + @assert size(A, 1) == 3 && ndims(A) == 2 "Inconsistent 1D force field and grid!" + As = reinterpret(reshape, SVector{3, eltype(A)}, A) + end + + itp = _fastinterp((gridx,), As, order, bc) + + return FieldInterpolator1D(itp, dir) +end + +function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} + min_phi, max_phi = extrema(gridϕ) + needs_0 = !isapprox(min_phi, 0, atol = 1.0e-5) + needs_2pi = !isapprox(max_phi, 2π, atol = 1.0e-5) + + if !needs_0 && !needs_2pi + return gridϕ, A + end + + new_grid_vec = collect(gridϕ) + if needs_0 + pushfirst!(new_grid_vec, 0.0) + end + if needs_2pi + push!(new_grid_vec, 2π) + end + + phi_dim = N + new_A = Array{T, N}(undef, (size(A)[1:(end - 1)]..., length(new_grid_vec))) + + start_idx = needs_0 ? 2 : 1 + end_idx = start_idx + length(gridϕ) - 1 + + selectdim(new_A, phi_dim, start_idx:end_idx) .= A + + if needs_0 + src_idx_for_0 = size(A, phi_dim) + selectdim(new_A, phi_dim, 1) .= selectdim(A, phi_dim, src_idx_for_0) + end + + if needs_2pi + if needs_0 + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) + else # needs 2π only + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) + end + end + + return new_grid_vec, new_A +end + +# Time-dependent field interpolation. + +""" + LazyTimeInterpolator{T, F, L} + +A callable struct for handling time-dependent fields with lazy loading and linear time interpolation. + +# Fields + +- `times::Vector{T}`: Sorted vector of time points. +- `loader::L`: Function `i -> field` that loads the field at index `i`. +- `buffer::Dict{Int, F}`: Cache for loaded fields. +- `lock::ReentrantLock`: Lock for thread safety. +""" +struct LazyTimeInterpolator{T, F, L} <: Function + times::Vector{T} + loader::L + buffer::Dict{Int, F} + lock::ReentrantLock +end + +function LazyTimeInterpolator(times::AbstractVector, loader::Function) + # Determine the field type by loading the first field + f1 = loader(1) + return _LazyTimeInterpolator(times, loader, f1) +end + +function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) where {F} + buffer = Dict{Int, F}(1 => f1) + lock = ReentrantLock() + return LazyTimeInterpolator{eltype(times), F, typeof(loader)}( + times, loader, buffer, lock + ) +end + +function (itp::LazyTimeInterpolator)(x, t) + # Find the time interval [t1, t2] such that t1 <= t <= t2 (assume times is sorted) + idx = searchsortedlast(itp.times, t) + + # Handle out-of-bounds + if idx == 0 + return _get_field!(itp, 1)(x) # clamp to start + elseif idx >= length(itp.times) + return _get_field!(itp, length(itp.times))(x) # clamp to end + end + + t1 = itp.times[idx] + t2 = itp.times[idx + 1] + + w = (t - t1) / (t2 - t1) # linear weights + + # Load fields (lazily) + f1 = _get_field!(itp, idx) + f2 = _get_field!(itp, idx + 1) + + return (1 - w) * f1(x) + w * f2(x) +end + +function _get_field!(itp::LazyTimeInterpolator, idx::Int) + return lock(itp.lock) do + if !haskey(itp.buffer, idx) + # Remove far-away indices + filter!(p -> abs(p.first - idx) <= 1, itp.buffer) + + field = itp.loader(idx) + itp.buffer[idx] = field + end + return itp.buffer[idx] + end +end diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 0cf81f2ba..e9ad141a0 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -1,7 +1,5 @@ # Field interpolations. -@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) - """ AbstractFieldInterpolator @@ -20,7 +18,7 @@ end const FieldInterpolator3D = FieldInterpolator -function (fi::FieldInterpolator)(xu) +@inbounds function (fi::FieldInterpolator)(xu) return fi.itp(xu[1], xu[2], xu[3]) end @@ -39,7 +37,7 @@ struct FieldInterpolator2D{T} <: AbstractFieldInterpolator itp::T end -function (fi::FieldInterpolator2D)(xu) +@inbounds function (fi::FieldInterpolator2D)(xu) # 2D interpolation usually involves x and y return fi.itp(xu[1], xu[2]) end @@ -60,7 +58,7 @@ struct FieldInterpolator1D{T} <: AbstractFieldInterpolator dir::Int end -function (fi::FieldInterpolator1D)(xu) +@inbounds function (fi::FieldInterpolator1D)(xu) return fi.itp(xu[fi.dir]) end @@ -97,6 +95,8 @@ end Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) + """ build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) build_interpolator(A, grids..., order::Int=1, bc::Int=1) @@ -115,7 +115,8 @@ The input array `A` may be modified in-place for memory optimization. """ function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -124,7 +125,8 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -138,7 +140,8 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -147,7 +150,8 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -175,7 +179,7 @@ end function build_interpolator( ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -184,28 +188,59 @@ end function build_interpolator( ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end - # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) - # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! - is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) - - if is_uniform_r - itp_unscaled = _get_interp_object(StructuredGrid, A, order, bc) - itp = scale(itp_unscaled, gridr, gridθ, gridϕ) - else # Non-uniform R (SphericalNonUniformR behavior) - bctype = (Flat(), Flat(), Periodic()) - gridϕ, A = _ensure_full_phi(gridϕ, A) + r_min, r_max = extrema(gridr) + θ_min, θ_max = extrema(gridθ) + ϕ_min, ϕ_max = extrema(gridϕ) + + @assert r_min >= 0 "r must be non-negative." + @assert θ_min >= 0 && θ_max <= π "θ must be within [0, π]." + @assert ϕ_min >= 0 && ϕ_max <= 2π "ϕ must be within [0, 2π]." + + has_0 = isapprox(ϕ_min, 0, atol = 1.0e-5) + has_2pi = isapprox(ϕ_max, 2π, atol = 1.0e-5) + + ϕ_bc = if has_0 && has_2pi + Periodic(OnGrid()) + else + Periodic(OnCell()) + end + + bctype = (Flat(), Flat(), ϕ_bc) + + if order == 1 itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) + else + interp_type = if order == 2 + Quadratic + elseif order == 3 + Cubic + else + throw(ArgumentError("Unsupported interpolation order!")) + end + itp_type = ( + BSpline(interp_type(Flat(OnCell()))), + BSpline(interp_type(Flat(OnCell()))), + BSpline(interp_type(ϕ_bc)), + ) + itp_obj = extrapolate(interpolate(A, itp_type), bctype) + itp = scale(itp_obj, gridr, gridθ, gridϕ) end + bctype = (Flat(), Flat(), phi_bc) + itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) + return SphericalFieldInterpolator(itp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2) +function build_interpolator( + ::Type{<:CartesianGrid}, A, + gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 1 + ) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." As = A @@ -220,7 +255,10 @@ function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, g return FieldInterpolator2D(interp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1) +function build_interpolator( + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + order::Int = 1, bc::Int = 1; dir = 1 + ) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." As = A @@ -280,64 +318,6 @@ function _get_interp_object(A, order::Int, bc::Int) return extrapolate(interpolate(A, bspline), bctype) end -function _get_interp_object(::Type{<:StructuredGrid}, A, order::Int, bc::Int) - bspline_r = _get_bspline(order, false) - bspline_θ = _get_bspline(order, false) - bspline_ϕ = _get_bspline(order, true) - - itp_type = (bspline_r, bspline_θ, bspline_ϕ) - - bctype = if eltype(A) <: SVector - SVector{3, eltype(eltype(A))}(NaN, NaN, NaN) - else - eltype(A)(NaN) - end - - return extrapolate(interpolate(A, itp_type), bctype) -end - -function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} - min_phi, max_phi = extrema(gridϕ) - needs_0 = !isapprox(min_phi, 0, atol = 1.0e-5) - needs_2pi = !isapprox(max_phi, 2π, atol = 1.0e-5) - - if !needs_0 && !needs_2pi - return gridϕ, A - end - - new_grid_vec = collect(gridϕ) - if needs_0 - pushfirst!(new_grid_vec, 0.0) - end - if needs_2pi - push!(new_grid_vec, 2π) - end - - phi_dim = N - new_A = Array{T, N}(undef, (size(A)[1:(end - 1)]..., length(new_grid_vec))) - - start_idx = needs_0 ? 2 : 1 - end_idx = start_idx + length(gridϕ) - 1 - - selectdim(new_A, phi_dim, start_idx:end_idx) .= A - - if needs_0 - src_idx_for_0 = size(A, phi_dim) - selectdim(new_A, phi_dim, 1) .= selectdim(A, phi_dim, src_idx_for_0) - end - - if needs_2pi - if needs_0 - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) - else # needs 2π only - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim( - A, phi_dim, 1 - ) - end - end - - return new_grid_vec, new_A -end # Time-dependent field interpolation. diff --git a/test/runtests.jl b/test/runtests.jl index 7d11c2f3e..65a68ecff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -170,7 +170,7 @@ end r, θ, ϕ, B_sph = setup_spherical_field() param = prepare(r, θ, ϕ, zero_E, B_sph; gridtype = StructuredGrid) # Check field interpolation Bz - @test param[4](SA[1.0, 1.0, 1.0])[3] == 9.888387888463716e-9 + @test param[4](SA[1.0, 1.0, 1.0])[3] ≈ 9.888387888463716e-9 # Test Vector inputs for spherical grid r_vec = collect(r) @@ -181,7 +181,7 @@ end r_vec, θ_vec, ϕ_vec, zero_E, B_sph; species = Ion(m = 16, q = 1), gridtype = TP.StructuredGrid ) - @test param_vec[4](SA[1.0, 1.0, 1.0])[3] == 9.888387888463716e-9 + @test param_vec[4](SA[1.0, 1.0, 1.0])[3] ≈ 9.888387888463716e-9 end @testset "analytical field" begin diff --git a/test/test_utility.jl b/test/test_utility.jl index 5c3bba260..593d02f20 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -50,27 +50,30 @@ import TestParticle as TP for i in eachindex(x), j in eachindex(y), k in eachindex(z) ] nfunc11 = TP.build_interpolator(n, x, y, z) - @test nfunc11(SA[9, 0, 0]) == 11.85 + @test isnan(nfunc11(SA[20, 0, 0])) + @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 nfunc12 = TP.build_interpolator(n, x, y, z, 1, 2) - @test nfunc12(SA[20, 0, 0]) == 9.5 + @test nfunc12(SA[20, 0, 0]) ≈ 10.5f0 nfunc13 = TP.build_interpolator(n, x, y, z, 1, 3) - @test nfunc13(SA[20, 0, 0]) == 12.0 + @test nfunc13(SA[20, 0, 0]) ≈ 12.0f0 nfunc21 = TP.build_interpolator(n, x, y, z, 2) - @test nfunc21(SA[9, 0, 0]) == 11.898528302013874 + @test isnan(nfunc21(SA[20, 0, 0])) + @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 nfunc22 = TP.build_interpolator(n, x, y, z, 2, 2) - @test nfunc22(SA[20, 0, 0]) == 9.166666686534882 + @test nfunc22(SA[20, 0, 0]) ≈ 10.5f0 nfunc23 = TP.build_interpolator(n, x, y, z, 2, 3) - @test nfunc23(SA[20, 0, 0]) == 12.14705765247345 + @test nfunc23(SA[20, 0, 0]) ≈ 12.0f0 nfunc31 = TP.build_interpolator(n, x, y, z, 3) - @test nfunc31(SA[9, 0, 0]) == 11.882999392215163 + @test isnan(nfunc31(SA[20, 0, 0])) + @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 nfunc32 = TP.build_interpolator(n, x, y, z, 3, 2) - @test nfunc32(SA[20, 0, 0]) == 9.124999547351358 + @test nfunc32(SA[20, 0, 0]) ≈ 10.5f0 nfunc33 = TP.build_interpolator(n, x, y, z, 3, 3) - @test nfunc33(SA[20, 0, 0]) == 12.191176189381315 + @test nfunc33(SA[20, 0, 0]) ≈ 12.0f0 end begin # spherical interpolation - r = range(0, 10, length = 11) + r = range(0.1, 10, length = 11) θ = range(0, π, length = 11) ϕ = range(0, 2π, length = 11) # Vector field @@ -82,7 +85,18 @@ import TestParticle as TP A = ones(length(r), length(θ), length(ϕ)) Afunc = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ) @test Afunc(SA[1, 1, 1]) == 1.0 - @test Afunc(SA[0, 0, 0]) == 1.0 + @test isnan(Afunc(SA[0, 0, 0])) # outside domain -> NaN with bc=1 + + # High order spherical interpolation + Bfunc2 = TP.build_interpolator(TP.StructuredGrid, B, r, θ, ϕ, 2) + @test Bfunc2(SA[1, 1, 1]) ≈ [0.57735, 0.57735, 0.57735] atol = 1.0e-5 + Afunc2 = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ, 2) + @test Afunc2(SA[1, 1, 1]) ≈ 1.0 + + Bfunc3 = TP.build_interpolator(TP.StructuredGrid, B, r, θ, ϕ, 3) + @test Bfunc3(SA[1, 1, 1]) ≈ [0.57735, 0.57735, 0.57735] atol = 1.0e-5 + Afunc3 = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ, 3) + @test Afunc3(SA[1, 1, 1]) ≈ 1.0 end begin # non-uniform spherical interpolation @@ -98,7 +112,7 @@ import TestParticle as TP A = ones(length(r), length(θ), length(ϕ)) Afunc = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ) @test Afunc(SA[1, 1, 1]) == 1.0 - @test Afunc(SA[0, 0, 0]) == 1.0 + @test isnan(Afunc(SA[0, 0, 0])) # outside domain -> NaN with bc=1 end end