From 2f1a7dbea7f250bfa377fc3fb20a9c13a7ca7309 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 10:24:15 -0500 Subject: [PATCH 01/27] refactor: substitute Interpolations.jl with FastInterpolations.jl --- Project.toml | 4 +- src/TestParticle.jl | 5 +- src/utility/fastinterpolation.jl | 434 +++++++++++++++++++++++++++++++ 3 files changed, 437 insertions(+), 6 deletions(-) create mode 100644 src/utility/fastinterpolation.jl diff --git a/Project.toml b/Project.toml index 2b7d7b7f9..4c535fdb7 100644 --- a/Project.toml +++ b/Project.toml @@ -11,8 +11,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" @@ -38,8 +38,8 @@ Adapt = "4.4" ChunkSplitters = "3" DiffResults = "1" Distributed = "1" +FastInterpolations = "0.3.0" ForwardDiff = "1" -Interpolations = "0.15, 0.16" KernelAbstractions = "0.9" LinearAlgebra = "1" Meshes = "0.55, 0.56" diff --git a/src/TestParticle.jl b/src/TestParticle.jl index f5c1bed22..9ca9970ba 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -1,9 +1,6 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize -using Interpolations: interpolate, interpolate!, extrapolate, scale, BSpline, Linear, - Quadratic, Cubic, - Line, OnCell, Periodic, Flat, Gridded using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, @@ -46,7 +43,7 @@ export EnsembleSerial, EnsembleThreads, EnsembleDistributed, remake include("types.jl") include("utility/utility.jl") -include("utility/interpolation.jl") +include("utility/fastinterpolation.jl") include("sampler.jl") include("prepare.jl") include("gc.jl") diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl new file mode 100644 index 000000000..5f642a046 --- /dev/null +++ b/src/utility/fastinterpolation.jl @@ -0,0 +1,434 @@ +# Field interpolations using FastInterpolations.jl. + +using FastInterpolations +using StaticArrays +using Adapt + +import FastInterpolations: _value_type +# Promote rule to construct the interpolant object natively for SVector. +_value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} + +@inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) + +""" + AbstractFieldInterpolator + +Abstract type for all field interpolators. +""" +abstract type AbstractFieldInterpolator <: Function end + +""" + FieldInterpolator{T, G} + +A callable struct that wraps a 3D interpolation object and its grid. +""" +struct FieldInterpolator{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +_in_bounds(x, gridx) = first(gridx) <= x <= last(gridx) + +function (fi::FieldInterpolator)(xu) + if fi.bc == 1 + if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2]) && _in_bounds(xu[3], fi.grid[3])) + T_val = typeof(fi.itp((xu[1], xu[2], xu[3]))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[1], xu[2], xu[3])) +end + +function (fi::FieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +""" + FieldInterpolator2D{T, G} + +A callable struct that wraps a 2D interpolation object. +""" +struct FieldInterpolator2D{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +function (fi::FieldInterpolator2D)(xu) + if fi.bc == 1 + if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2])) + T_val = typeof(fi.itp((xu[1], xu[2]))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[1], xu[2])) +end + +function (fi::FieldInterpolator2D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +""" + FieldInterpolator1D{T, G} + +A callable struct that wraps a 1D interpolation object. +""" +struct FieldInterpolator1D{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int + dir::Int +end + +function (fi::FieldInterpolator1D)(xu) + if fi.bc == 1 + if !_in_bounds(xu[fi.dir], fi.grid) + T_val = typeof(fi.itp((xu[fi.dir],))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[fi.dir],)) +end + +function (fi::FieldInterpolator1D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc, fi.dir) + +""" + SphericalFieldInterpolator{T, G} + +A callable struct for spherical grid interpolation. +""" +struct SphericalFieldInterpolator{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +function (fi::SphericalFieldInterpolator)(xu) + r_val, θ_val, ϕ_val = cart2sph(xu) + + if fi.bc == 1 + if !(_in_bounds(r_val, fi.grid[1]) && _in_bounds(θ_val, fi.grid[2]) && _in_bounds(ϕ_val, fi.grid[3])) + res_val = fi.itp((r_val, θ_val, ϕ_val)) + if typeof(res_val) <: SVector || length(res_val) > 1 + return fill(eltype(res_val)(NaN), length(res_val)) + else + return typeof(res_val)(NaN) + end + end + end + + res = fi.itp((r_val, θ_val, ϕ_val)) + if typeof(res) <: SVector || length(res) > 1 + Br, Bθ, Bϕ = res + return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) + else + return res + end +end + +function (fi::SphericalFieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +function getinterp_scalar(A, grid1, grid2, grid3, args...) + return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) +end + +function _get_extrap_mode(bc) + if bc == 2 + return Extrap(:wrap) + elseif bc == 3 + return Extrap(:constant) + else + # For bc == 1, we handle NaN in wrapper and use NoExtrap() for inner to avoid errors + return NoExtrap() + end +end + +struct ComponentInterpolator{T1, T2, T3} + itp1::T1 + itp2::T2 + itp3::T3 +end + +function (ci::ComponentInterpolator)(x) + return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] +end + +function _fastinterp(grids, A, order, bc) + extrap_mode = _get_extrap_mode(bc) + if order == 1 + return linear_interp(grids, A; extrap = extrap_mode) + elseif order == 2 + return quadratic_interp(grids, A; extrap = extrap_mode) + elseif order == 3 + if eltype(A) <: SVector{3} + A1 = [v[1] for v in A] + A2 = [v[2] for v in A] + A3 = [v[3] for v in A] + itp1 = cubic_interp(grids, A1; extrap = extrap_mode) + itp2 = cubic_interp(grids, A2; extrap = extrap_mode) + itp3 = cubic_interp(grids, A3; extrap = extrap_mode) + return ComponentInterpolator(itp1, itp2, itp3) + end + return cubic_interp(grids, A; extrap = extrap_mode) + else + return constant_interp(grids, A; extrap = extrap_mode) + end +end + +function getinterp( + ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp( + ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp( + ::Type{<:StructuredGrid}, A, gridr, gridθ, gridϕ, + order::Int = 1, bc::Int = 3 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) +end + +function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc::Int = 2) + if eltype(A) <: SVector + @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." + As = A + else + @assert size(A, 1) == 3 && ndims(A) == 3 "Inconsistent 2D force field and grid!" + As = reinterpret(reshape, SVector{3, eltype(A)}, A) + end + + itp = _fastinterp((gridx, gridy), As, order, bc) + return FieldInterpolator2D(itp, (gridx, gridy), bc) +end + +function get_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) +end + +function get_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + if order != 1 + throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) + end + + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) +end + +function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = 3; dir = 1) + if eltype(A) <: SVector + @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." + As = A + else + @assert size(A, 1) == 3 && ndims(A) == 2 "Inconsistent 1D force field and grid!" + As = reinterpret(reshape, SVector{3, eltype(A)}, A) + end + + itp = _fastinterp((gridx,), As, order, bc) + return FieldInterpolator1D(itp, gridx, bc, dir) +end + +function getinterp_scalar( + ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp_scalar( + ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp_scalar( + ::Type{<:StructuredGrid}, A, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) + return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) +end + +function get_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) +end + +function get_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) +end + +function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} + min_phi, max_phi = extrema(gridϕ) + needs_0 = !isapprox(min_phi, 0, atol = 1.0e-5) + needs_2pi = !isapprox(max_phi, 2π, atol = 1.0e-5) + + if !needs_0 && !needs_2pi + return gridϕ, A + end + + new_grid_vec = collect(gridϕ) + if needs_0 + pushfirst!(new_grid_vec, 0.0) + end + if needs_2pi + push!(new_grid_vec, 2π) + end + + phi_dim = N + new_A = Array{T, N}(undef, (size(A)[1:(end - 1)]..., length(new_grid_vec))) + + start_idx = needs_0 ? 2 : 1 + end_idx = start_idx + length(gridϕ) - 1 + + selectdim(new_A, phi_dim, start_idx:end_idx) .= A + + if needs_0 + src_idx_for_0 = size(A, phi_dim) + selectdim(new_A, phi_dim, 1) .= selectdim(A, phi_dim, src_idx_for_0) + end + + if needs_2pi + if needs_0 + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) + else + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) + end + end + + return new_grid_vec, new_A +end + +function get_interpolator( + ::Type{<:StructuredGrid}, + A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) +end + +function get_interpolator( + ::Type{<:StructuredGrid}, + A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 + ) where {T} + is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) + + if is_uniform_r + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + else + gridϕ, A = _ensure_full_phi(gridϕ, A) + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + end + + return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) +end + +# Time-dependent field interpolation. + +struct LazyTimeInterpolator{T, F, L} <: Function + times::Vector{T} + loader::L + buffer::Dict{Int, F} + lock::ReentrantLock +end + +function LazyTimeInterpolator(times::AbstractVector, loader::Function) + f1 = loader(1) + return _LazyTimeInterpolator(times, loader, f1) +end + +function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) where {F} + buffer = Dict{Int, F}(1 => f1) + lock = ReentrantLock() + return LazyTimeInterpolator{eltype(times), F, typeof(loader)}( + times, loader, buffer, lock + ) +end + +function (itp::LazyTimeInterpolator)(x, t) + idx = searchsortedlast(itp.times, t) + + if idx == 0 + return _get_field!(itp, 1)(x) + elseif idx >= length(itp.times) + return _get_field!(itp, length(itp.times))(x) + end + + t1 = itp.times[idx] + t2 = itp.times[idx + 1] + + w = (t - t1) / (t2 - t1) + + f1 = _get_field!(itp, idx) + f2 = _get_field!(itp, idx + 1) + + return (1 - w) * f1(x) + w * f2(x) +end + +function _get_field!(itp::LazyTimeInterpolator, idx::Int) + return lock(itp.lock) do + if !haskey(itp.buffer, idx) + filter!(p -> abs(p.first - idx) <= 1, itp.buffer) + field = itp.loader(idx) + itp.buffer[idx] = field + end + return itp.buffer[idx] + end +end From d90e6c3fdd66ae25d95fc31068c91c05fd813296 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 22:50:15 -0500 Subject: [PATCH 02/27] Try to fix the Float32 vs Float64 bug --- src/utility/fastinterpolation.jl | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 5f642a046..a300e1bd3 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -178,25 +178,40 @@ function (ci::ComponentInterpolator)(x) return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] end +function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} + return range(T(first(g)), T(last(g)), length = length(g)) +end +function _match_grid_type(g::AbstractVector, ::Type{T}) where {T <: AbstractFloat} + return T.(g) +end +function _match_grid_type(g::Tuple, ::Type{T}) where {T <: AbstractFloat} + return T.(g) +end + function _fastinterp(grids, A, order, bc) + T_A = eltype(A) + T_F = T_A <: SVector ? eltype(T_A) : T_A + T_F = T_F <: AbstractFloat ? T_F : Float64 + matched_grids = map(g -> _match_grid_type(g, T_F), grids) + extrap_mode = _get_extrap_mode(bc) if order == 1 - return linear_interp(grids, A; extrap = extrap_mode) + return linear_interp(matched_grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(grids, A; extrap = extrap_mode) + return quadratic_interp(matched_grids, A; extrap = extrap_mode) elseif order == 3 if eltype(A) <: SVector{3} A1 = [v[1] for v in A] A2 = [v[2] for v in A] A3 = [v[3] for v in A] - itp1 = cubic_interp(grids, A1; extrap = extrap_mode) - itp2 = cubic_interp(grids, A2; extrap = extrap_mode) - itp3 = cubic_interp(grids, A3; extrap = extrap_mode) + itp1 = cubic_interp(matched_grids, A1; extrap = extrap_mode) + itp2 = cubic_interp(matched_grids, A2; extrap = extrap_mode) + itp3 = cubic_interp(matched_grids, A3; extrap = extrap_mode) return ComponentInterpolator(itp1, itp2, itp3) end - return cubic_interp(grids, A; extrap = extrap_mode) + return cubic_interp(matched_grids, A; extrap = extrap_mode) else - return constant_interp(grids, A; extrap = extrap_mode) + return constant_interp(matched_grids, A; extrap = extrap_mode) end end From c01cc53b92fa84513871476f021cbb6c4d02bbc0 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 23:12:19 -0500 Subject: [PATCH 03/27] fix out-of-bound return type --- src/utility/fastinterpolation.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index a300e1bd3..1f8429d25 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -8,6 +8,9 @@ import FastInterpolations: _value_type # Promote rule to construct the interpolant object natively for SVector. _value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} +Base.eltype(::FastInterpolations.AbstractInterpolant{Tx, Ty}) where {Tx, Ty} = Ty +Base.eltype(::FastInterpolations.AbstractInterpolantND{Tx, Ty, N}) where {Tx, Ty, N} = Ty + @inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) """ @@ -33,9 +36,9 @@ _in_bounds(x, gridx) = first(gridx) <= x <= last(gridx) function (fi::FieldInterpolator)(xu) if fi.bc == 1 if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2]) && _in_bounds(xu[3], fi.grid[3])) - T_val = typeof(fi.itp((xu[1], xu[2], xu[3]))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -64,9 +67,9 @@ end function (fi::FieldInterpolator2D)(xu) if fi.bc == 1 if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2])) - T_val = typeof(fi.itp((xu[1], xu[2]))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -96,9 +99,9 @@ end function (fi::FieldInterpolator1D)(xu) if fi.bc == 1 if !_in_bounds(xu[fi.dir], fi.grid) - T_val = typeof(fi.itp((xu[fi.dir],))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -129,11 +132,11 @@ function (fi::SphericalFieldInterpolator)(xu) if fi.bc == 1 if !(_in_bounds(r_val, fi.grid[1]) && _in_bounds(θ_val, fi.grid[2]) && _in_bounds(ϕ_val, fi.grid[3])) - res_val = fi.itp((r_val, θ_val, ϕ_val)) - if typeof(res_val) <: SVector || length(res_val) > 1 - return fill(eltype(res_val)(NaN), length(res_val)) + T_val = eltype(fi.itp) + if T_val <: SVector + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else - return typeof(res_val)(NaN) + return T_val(NaN) end end end @@ -177,6 +180,7 @@ end function (ci::ComponentInterpolator)(x) return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] end +Base.eltype(::ComponentInterpolator{T1, T2, T3}) where {T1, T2, T3} = SVector{3, eltype(T1)} function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} return range(T(first(g)), T(last(g)), length = length(g)) From fe403d11a00d377a0d3848fb88b34ff25b494eaa Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 13:56:57 -0500 Subject: [PATCH 04/27] Update FastInterpolations.jl to v0.4 --- Project.toml | 2 +- src/utility/fastinterpolation.jl | 225 +++++++++++++------------------ test/runtests.jl | 4 +- test/test_utility.jl | 18 +-- 4 files changed, 108 insertions(+), 141 deletions(-) diff --git a/Project.toml b/Project.toml index 4c535fdb7..2a1451034 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ Adapt = "4.4" ChunkSplitters = "3" DiffResults = "1" Distributed = "1" -FastInterpolations = "0.3.0" +FastInterpolations = "0.4" ForwardDiff = "1" KernelAbstractions = "0.9" LinearAlgebra = "1" diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 1f8429d25..f9b7a348f 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -4,14 +4,10 @@ using FastInterpolations using StaticArrays using Adapt -import FastInterpolations: _value_type -# Promote rule to construct the interpolant object natively for SVector. -_value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} - Base.eltype(::FastInterpolations.AbstractInterpolant{Tx, Ty}) where {Tx, Ty} = Ty Base.eltype(::FastInterpolations.AbstractInterpolantND{Tx, Ty, N}) where {Tx, Ty, N} = Ty -@inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) """ AbstractFieldInterpolator @@ -156,10 +152,6 @@ end Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) -function getinterp_scalar(A, grid1, grid2, grid3, args...) - return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) -end - function _get_extrap_mode(bc) if bc == 2 return Extrap(:wrap) @@ -171,17 +163,6 @@ function _get_extrap_mode(bc) end end -struct ComponentInterpolator{T1, T2, T3} - itp1::T1 - itp2::T2 - itp3::T3 -end - -function (ci::ComponentInterpolator)(x) - return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] -end -Base.eltype(::ComponentInterpolator{T1, T2, T3}) where {T1, T2, T3} = SVector{3, eltype(T1)} - function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} return range(T(first(g)), T(last(g)), length = length(g)) end @@ -204,56 +185,102 @@ function _fastinterp(grids, A, order, bc) elseif order == 2 return quadratic_interp(matched_grids, A; extrap = extrap_mode) elseif order == 3 - if eltype(A) <: SVector{3} - A1 = [v[1] for v in A] - A2 = [v[2] for v in A] - A3 = [v[3] for v in A] - itp1 = cubic_interp(matched_grids, A1; extrap = extrap_mode) - itp2 = cubic_interp(matched_grids, A2; extrap = extrap_mode) - itp3 = cubic_interp(matched_grids, A3; extrap = extrap_mode) - return ComponentInterpolator(itp1, itp2, itp3) - end return cubic_interp(matched_grids, A; extrap = extrap_mode) else return constant_interp(matched_grids, A; extrap = extrap_mode) end end -function getinterp( - ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) +""" + build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) + build_interpolator(A, grids..., order::Int=1, bc::Int=1) + +Return a function for interpolating field array `A` on the given grids. + +# Arguments + + - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. + - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. + - `order::Int=1`: order of interpolation in [1,2,3]. + - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Flat. + +# Notes +The input array `A` may be modified in-place for memory optimization. +""" +function build_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) +end + +function build_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." - else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" end - return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) end -function getinterp( - ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) +function build_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) +end + +function build_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." - else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" end - return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) + if order != 1 + throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) + end + + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) end -function getinterp( - ::Type{<:StructuredGrid}, A, gridr, gridθ, gridϕ, - order::Int = 1, bc::Int = 3 - ) +function build_interpolator( + ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) +end + +function build_interpolator( + ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + end + + is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) + + if is_uniform_r + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + gridϕ, A = _ensure_full_phi(gridϕ, A) + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) end - return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) + + return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) end -function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc::Int = 2) +function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." As = A @@ -266,27 +293,7 @@ function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc: return FieldInterpolator2D(itp, (gridx, gridy), bc) end -function get_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) -end - -function get_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - if order != 1 - throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) - end - - itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) -end - -function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = 3; dir = 1) +function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." As = A @@ -296,42 +303,8 @@ function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = end itp = _fastinterp((gridx,), As, order, bc) - return FieldInterpolator1D(itp, gridx, bc, dir) -end - -function getinterp_scalar( - ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) - return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) -end - -function getinterp_scalar( - ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) - return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) -end -function getinterp_scalar( - ::Type{<:StructuredGrid}, A, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 - ) - return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) -end - -function get_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) -end - -function get_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) + return FieldInterpolator1D(itp, gridx, bc, dir) end function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} @@ -375,32 +348,20 @@ function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} return new_grid_vec, new_A end -function get_interpolator( - ::Type{<:StructuredGrid}, - A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) -end - -function get_interpolator( - ::Type{<:StructuredGrid}, - A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 - ) where {T} - is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) +# Time-dependent field interpolation. - if is_uniform_r - itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - else - gridϕ, A = _ensure_full_phi(gridϕ, A) - itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - end +""" + LazyTimeInterpolator{T, F, L} - return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) -end +A callable struct for handling time-dependent fields with lazy loading and linear time interpolation. -# Time-dependent field interpolation. +# Fields +- `times::Vector{T}`: Sorted vector of time points. +- `loader::L`: Function `i -> field` that loads the field at index `i`. +- `buffer::Dict{Int, F}`: Cache for loaded fields. +- `lock::ReentrantLock`: Lock for thread safety. +""" struct LazyTimeInterpolator{T, F, L} <: Function times::Vector{T} loader::L @@ -409,6 +370,7 @@ struct LazyTimeInterpolator{T, F, L} <: Function end function LazyTimeInterpolator(times::AbstractVector, loader::Function) + # Determine the field type by loading the first field f1 = loader(1) return _LazyTimeInterpolator(times, loader, f1) end @@ -422,19 +384,22 @@ function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) w end function (itp::LazyTimeInterpolator)(x, t) + # Find the time interval [t1, t2] such that t1 <= t <= t2 (assume times is sorted) idx = searchsortedlast(itp.times, t) + # Handle out-of-bounds if idx == 0 - return _get_field!(itp, 1)(x) + return _get_field!(itp, 1)(x) # clamp to start elseif idx >= length(itp.times) - return _get_field!(itp, length(itp.times))(x) + return _get_field!(itp, length(itp.times))(x) # clamp to end end t1 = itp.times[idx] t2 = itp.times[idx + 1] - w = (t - t1) / (t2 - t1) + w = (t - t1) / (t2 - t1) # linear weights + # Load fields (lazily) f1 = _get_field!(itp, idx) f2 = _get_field!(itp, idx + 1) @@ -444,7 +409,9 @@ end function _get_field!(itp::LazyTimeInterpolator, idx::Int) return lock(itp.lock) do if !haskey(itp.buffer, idx) + # Remove far-away indices filter!(p -> abs(p.first - idx) <= 1, itp.buffer) + field = itp.loader(idx) itp.buffer[idx] = field end diff --git a/test/runtests.jl b/test/runtests.jl index 7d11c2f3e..65a68ecff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -170,7 +170,7 @@ end r, θ, ϕ, B_sph = setup_spherical_field() param = prepare(r, θ, ϕ, zero_E, B_sph; gridtype = StructuredGrid) # Check field interpolation Bz - @test param[4](SA[1.0, 1.0, 1.0])[3] == 9.888387888463716e-9 + @test param[4](SA[1.0, 1.0, 1.0])[3] ≈ 9.888387888463716e-9 # Test Vector inputs for spherical grid r_vec = collect(r) @@ -181,7 +181,7 @@ end r_vec, θ_vec, ϕ_vec, zero_E, B_sph; species = Ion(m = 16, q = 1), gridtype = TP.StructuredGrid ) - @test param_vec[4](SA[1.0, 1.0, 1.0])[3] == 9.888387888463716e-9 + @test param_vec[4](SA[1.0, 1.0, 1.0])[3] ≈ 9.888387888463716e-9 end @testset "analytical field" begin diff --git a/test/test_utility.jl b/test/test_utility.jl index 5c3bba260..27b000f81 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -50,23 +50,23 @@ import TestParticle as TP for i in eachindex(x), j in eachindex(y), k in eachindex(z) ] nfunc11 = TP.build_interpolator(n, x, y, z) - @test nfunc11(SA[9, 0, 0]) == 11.85 + @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 nfunc12 = TP.build_interpolator(n, x, y, z, 1, 2) - @test nfunc12(SA[20, 0, 0]) == 9.5 + @test nfunc12(SA[20, 0, 0]) ≈ 10.5f0 nfunc13 = TP.build_interpolator(n, x, y, z, 1, 3) - @test nfunc13(SA[20, 0, 0]) == 12.0 + @test nfunc13(SA[20, 0, 0]) ≈ 12.0f0 nfunc21 = TP.build_interpolator(n, x, y, z, 2) - @test nfunc21(SA[9, 0, 0]) == 11.898528302013874 + @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 nfunc22 = TP.build_interpolator(n, x, y, z, 2, 2) - @test nfunc22(SA[20, 0, 0]) == 9.166666686534882 + @test nfunc22(SA[20, 0, 0]) ≈ 10.5f0 nfunc23 = TP.build_interpolator(n, x, y, z, 2, 3) - @test nfunc23(SA[20, 0, 0]) == 12.14705765247345 + @test nfunc23(SA[20, 0, 0]) ≈ 12.0f0 nfunc31 = TP.build_interpolator(n, x, y, z, 3) - @test nfunc31(SA[9, 0, 0]) == 11.882999392215163 + @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 nfunc32 = TP.build_interpolator(n, x, y, z, 3, 2) - @test nfunc32(SA[20, 0, 0]) == 9.124999547351358 + @test nfunc32(SA[20, 0, 0]) ≈ 10.5f0 nfunc33 = TP.build_interpolator(n, x, y, z, 3, 3) - @test nfunc33(SA[20, 0, 0]) == 12.191176189381315 + @test nfunc33(SA[20, 0, 0]) ≈ 12.0f0 end begin # spherical interpolation From c3283cd5b88edd66f0dd798340aef64e96f02880 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 14:26:03 -0500 Subject: [PATCH 05/27] perf: suppress bounds check for FieldInterpolators (#484) --- src/utility/interpolation.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 0cf81f2ba..90dff4f66 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -20,7 +20,7 @@ end const FieldInterpolator3D = FieldInterpolator -function (fi::FieldInterpolator)(xu) +@inbounds function (fi::FieldInterpolator)(xu) return fi.itp(xu[1], xu[2], xu[3]) end @@ -39,7 +39,7 @@ struct FieldInterpolator2D{T} <: AbstractFieldInterpolator itp::T end -function (fi::FieldInterpolator2D)(xu) +@inbounds function (fi::FieldInterpolator2D)(xu) # 2D interpolation usually involves x and y return fi.itp(xu[1], xu[2]) end @@ -330,9 +330,7 @@ function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} if needs_0 selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) else # needs 2π only - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim( - A, phi_dim, 1 - ) + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) end end From 3d7c0b9e685518a7ce02aab5453504319f9f6fd4 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 10:24:15 -0500 Subject: [PATCH 06/27] refactor: substitute Interpolations.jl with FastInterpolations.jl --- Project.toml | 4 +- src/TestParticle.jl | 5 +- src/utility/fastinterpolation.jl | 434 +++++++++++++++++++++++++++++++ 3 files changed, 437 insertions(+), 6 deletions(-) create mode 100644 src/utility/fastinterpolation.jl diff --git a/Project.toml b/Project.toml index 2b7d7b7f9..4c535fdb7 100644 --- a/Project.toml +++ b/Project.toml @@ -11,8 +11,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" @@ -38,8 +38,8 @@ Adapt = "4.4" ChunkSplitters = "3" DiffResults = "1" Distributed = "1" +FastInterpolations = "0.3.0" ForwardDiff = "1" -Interpolations = "0.15, 0.16" KernelAbstractions = "0.9" LinearAlgebra = "1" Meshes = "0.55, 0.56" diff --git a/src/TestParticle.jl b/src/TestParticle.jl index f5c1bed22..9ca9970ba 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -1,9 +1,6 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize -using Interpolations: interpolate, interpolate!, extrapolate, scale, BSpline, Linear, - Quadratic, Cubic, - Line, OnCell, Periodic, Flat, Gridded using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, @@ -46,7 +43,7 @@ export EnsembleSerial, EnsembleThreads, EnsembleDistributed, remake include("types.jl") include("utility/utility.jl") -include("utility/interpolation.jl") +include("utility/fastinterpolation.jl") include("sampler.jl") include("prepare.jl") include("gc.jl") diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl new file mode 100644 index 000000000..5f642a046 --- /dev/null +++ b/src/utility/fastinterpolation.jl @@ -0,0 +1,434 @@ +# Field interpolations using FastInterpolations.jl. + +using FastInterpolations +using StaticArrays +using Adapt + +import FastInterpolations: _value_type +# Promote rule to construct the interpolant object natively for SVector. +_value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} + +@inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) + +""" + AbstractFieldInterpolator + +Abstract type for all field interpolators. +""" +abstract type AbstractFieldInterpolator <: Function end + +""" + FieldInterpolator{T, G} + +A callable struct that wraps a 3D interpolation object and its grid. +""" +struct FieldInterpolator{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +_in_bounds(x, gridx) = first(gridx) <= x <= last(gridx) + +function (fi::FieldInterpolator)(xu) + if fi.bc == 1 + if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2]) && _in_bounds(xu[3], fi.grid[3])) + T_val = typeof(fi.itp((xu[1], xu[2], xu[3]))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[1], xu[2], xu[3])) +end + +function (fi::FieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +""" + FieldInterpolator2D{T, G} + +A callable struct that wraps a 2D interpolation object. +""" +struct FieldInterpolator2D{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +function (fi::FieldInterpolator2D)(xu) + if fi.bc == 1 + if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2])) + T_val = typeof(fi.itp((xu[1], xu[2]))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[1], xu[2])) +end + +function (fi::FieldInterpolator2D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +""" + FieldInterpolator1D{T, G} + +A callable struct that wraps a 1D interpolation object. +""" +struct FieldInterpolator1D{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int + dir::Int +end + +function (fi::FieldInterpolator1D)(xu) + if fi.bc == 1 + if !_in_bounds(xu[fi.dir], fi.grid) + T_val = typeof(fi.itp((xu[fi.dir],))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[fi.dir],)) +end + +function (fi::FieldInterpolator1D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc, fi.dir) + +""" + SphericalFieldInterpolator{T, G} + +A callable struct for spherical grid interpolation. +""" +struct SphericalFieldInterpolator{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +function (fi::SphericalFieldInterpolator)(xu) + r_val, θ_val, ϕ_val = cart2sph(xu) + + if fi.bc == 1 + if !(_in_bounds(r_val, fi.grid[1]) && _in_bounds(θ_val, fi.grid[2]) && _in_bounds(ϕ_val, fi.grid[3])) + res_val = fi.itp((r_val, θ_val, ϕ_val)) + if typeof(res_val) <: SVector || length(res_val) > 1 + return fill(eltype(res_val)(NaN), length(res_val)) + else + return typeof(res_val)(NaN) + end + end + end + + res = fi.itp((r_val, θ_val, ϕ_val)) + if typeof(res) <: SVector || length(res) > 1 + Br, Bθ, Bϕ = res + return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) + else + return res + end +end + +function (fi::SphericalFieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +function getinterp_scalar(A, grid1, grid2, grid3, args...) + return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) +end + +function _get_extrap_mode(bc) + if bc == 2 + return Extrap(:wrap) + elseif bc == 3 + return Extrap(:constant) + else + # For bc == 1, we handle NaN in wrapper and use NoExtrap() for inner to avoid errors + return NoExtrap() + end +end + +struct ComponentInterpolator{T1, T2, T3} + itp1::T1 + itp2::T2 + itp3::T3 +end + +function (ci::ComponentInterpolator)(x) + return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] +end + +function _fastinterp(grids, A, order, bc) + extrap_mode = _get_extrap_mode(bc) + if order == 1 + return linear_interp(grids, A; extrap = extrap_mode) + elseif order == 2 + return quadratic_interp(grids, A; extrap = extrap_mode) + elseif order == 3 + if eltype(A) <: SVector{3} + A1 = [v[1] for v in A] + A2 = [v[2] for v in A] + A3 = [v[3] for v in A] + itp1 = cubic_interp(grids, A1; extrap = extrap_mode) + itp2 = cubic_interp(grids, A2; extrap = extrap_mode) + itp3 = cubic_interp(grids, A3; extrap = extrap_mode) + return ComponentInterpolator(itp1, itp2, itp3) + end + return cubic_interp(grids, A; extrap = extrap_mode) + else + return constant_interp(grids, A; extrap = extrap_mode) + end +end + +function getinterp( + ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp( + ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp( + ::Type{<:StructuredGrid}, A, gridr, gridθ, gridϕ, + order::Int = 1, bc::Int = 3 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) +end + +function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc::Int = 2) + if eltype(A) <: SVector + @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." + As = A + else + @assert size(A, 1) == 3 && ndims(A) == 3 "Inconsistent 2D force field and grid!" + As = reinterpret(reshape, SVector{3, eltype(A)}, A) + end + + itp = _fastinterp((gridx, gridy), As, order, bc) + return FieldInterpolator2D(itp, (gridx, gridy), bc) +end + +function get_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) +end + +function get_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + if order != 1 + throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) + end + + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) +end + +function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = 3; dir = 1) + if eltype(A) <: SVector + @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." + As = A + else + @assert size(A, 1) == 3 && ndims(A) == 2 "Inconsistent 1D force field and grid!" + As = reinterpret(reshape, SVector{3, eltype(A)}, A) + end + + itp = _fastinterp((gridx,), As, order, bc) + return FieldInterpolator1D(itp, gridx, bc, dir) +end + +function getinterp_scalar( + ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp_scalar( + ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp_scalar( + ::Type{<:StructuredGrid}, A, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) + return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) +end + +function get_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) +end + +function get_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) +end + +function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} + min_phi, max_phi = extrema(gridϕ) + needs_0 = !isapprox(min_phi, 0, atol = 1.0e-5) + needs_2pi = !isapprox(max_phi, 2π, atol = 1.0e-5) + + if !needs_0 && !needs_2pi + return gridϕ, A + end + + new_grid_vec = collect(gridϕ) + if needs_0 + pushfirst!(new_grid_vec, 0.0) + end + if needs_2pi + push!(new_grid_vec, 2π) + end + + phi_dim = N + new_A = Array{T, N}(undef, (size(A)[1:(end - 1)]..., length(new_grid_vec))) + + start_idx = needs_0 ? 2 : 1 + end_idx = start_idx + length(gridϕ) - 1 + + selectdim(new_A, phi_dim, start_idx:end_idx) .= A + + if needs_0 + src_idx_for_0 = size(A, phi_dim) + selectdim(new_A, phi_dim, 1) .= selectdim(A, phi_dim, src_idx_for_0) + end + + if needs_2pi + if needs_0 + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) + else + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) + end + end + + return new_grid_vec, new_A +end + +function get_interpolator( + ::Type{<:StructuredGrid}, + A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) +end + +function get_interpolator( + ::Type{<:StructuredGrid}, + A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 + ) where {T} + is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) + + if is_uniform_r + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + else + gridϕ, A = _ensure_full_phi(gridϕ, A) + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + end + + return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) +end + +# Time-dependent field interpolation. + +struct LazyTimeInterpolator{T, F, L} <: Function + times::Vector{T} + loader::L + buffer::Dict{Int, F} + lock::ReentrantLock +end + +function LazyTimeInterpolator(times::AbstractVector, loader::Function) + f1 = loader(1) + return _LazyTimeInterpolator(times, loader, f1) +end + +function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) where {F} + buffer = Dict{Int, F}(1 => f1) + lock = ReentrantLock() + return LazyTimeInterpolator{eltype(times), F, typeof(loader)}( + times, loader, buffer, lock + ) +end + +function (itp::LazyTimeInterpolator)(x, t) + idx = searchsortedlast(itp.times, t) + + if idx == 0 + return _get_field!(itp, 1)(x) + elseif idx >= length(itp.times) + return _get_field!(itp, length(itp.times))(x) + end + + t1 = itp.times[idx] + t2 = itp.times[idx + 1] + + w = (t - t1) / (t2 - t1) + + f1 = _get_field!(itp, idx) + f2 = _get_field!(itp, idx + 1) + + return (1 - w) * f1(x) + w * f2(x) +end + +function _get_field!(itp::LazyTimeInterpolator, idx::Int) + return lock(itp.lock) do + if !haskey(itp.buffer, idx) + filter!(p -> abs(p.first - idx) <= 1, itp.buffer) + field = itp.loader(idx) + itp.buffer[idx] = field + end + return itp.buffer[idx] + end +end From e0b09a9158b6fc335b9600adf2569d768ee65dae Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 22:50:15 -0500 Subject: [PATCH 07/27] Try to fix the Float32 vs Float64 bug --- src/utility/fastinterpolation.jl | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 5f642a046..a300e1bd3 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -178,25 +178,40 @@ function (ci::ComponentInterpolator)(x) return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] end +function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} + return range(T(first(g)), T(last(g)), length = length(g)) +end +function _match_grid_type(g::AbstractVector, ::Type{T}) where {T <: AbstractFloat} + return T.(g) +end +function _match_grid_type(g::Tuple, ::Type{T}) where {T <: AbstractFloat} + return T.(g) +end + function _fastinterp(grids, A, order, bc) + T_A = eltype(A) + T_F = T_A <: SVector ? eltype(T_A) : T_A + T_F = T_F <: AbstractFloat ? T_F : Float64 + matched_grids = map(g -> _match_grid_type(g, T_F), grids) + extrap_mode = _get_extrap_mode(bc) if order == 1 - return linear_interp(grids, A; extrap = extrap_mode) + return linear_interp(matched_grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(grids, A; extrap = extrap_mode) + return quadratic_interp(matched_grids, A; extrap = extrap_mode) elseif order == 3 if eltype(A) <: SVector{3} A1 = [v[1] for v in A] A2 = [v[2] for v in A] A3 = [v[3] for v in A] - itp1 = cubic_interp(grids, A1; extrap = extrap_mode) - itp2 = cubic_interp(grids, A2; extrap = extrap_mode) - itp3 = cubic_interp(grids, A3; extrap = extrap_mode) + itp1 = cubic_interp(matched_grids, A1; extrap = extrap_mode) + itp2 = cubic_interp(matched_grids, A2; extrap = extrap_mode) + itp3 = cubic_interp(matched_grids, A3; extrap = extrap_mode) return ComponentInterpolator(itp1, itp2, itp3) end - return cubic_interp(grids, A; extrap = extrap_mode) + return cubic_interp(matched_grids, A; extrap = extrap_mode) else - return constant_interp(grids, A; extrap = extrap_mode) + return constant_interp(matched_grids, A; extrap = extrap_mode) end end From 5edd96a4b4d9cd67087dc8ad1b41b583edbdeaf9 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 23:12:19 -0500 Subject: [PATCH 08/27] fix out-of-bound return type --- src/utility/fastinterpolation.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index a300e1bd3..1f8429d25 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -8,6 +8,9 @@ import FastInterpolations: _value_type # Promote rule to construct the interpolant object natively for SVector. _value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} +Base.eltype(::FastInterpolations.AbstractInterpolant{Tx, Ty}) where {Tx, Ty} = Ty +Base.eltype(::FastInterpolations.AbstractInterpolantND{Tx, Ty, N}) where {Tx, Ty, N} = Ty + @inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) """ @@ -33,9 +36,9 @@ _in_bounds(x, gridx) = first(gridx) <= x <= last(gridx) function (fi::FieldInterpolator)(xu) if fi.bc == 1 if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2]) && _in_bounds(xu[3], fi.grid[3])) - T_val = typeof(fi.itp((xu[1], xu[2], xu[3]))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -64,9 +67,9 @@ end function (fi::FieldInterpolator2D)(xu) if fi.bc == 1 if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2])) - T_val = typeof(fi.itp((xu[1], xu[2]))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -96,9 +99,9 @@ end function (fi::FieldInterpolator1D)(xu) if fi.bc == 1 if !_in_bounds(xu[fi.dir], fi.grid) - T_val = typeof(fi.itp((xu[fi.dir],))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -129,11 +132,11 @@ function (fi::SphericalFieldInterpolator)(xu) if fi.bc == 1 if !(_in_bounds(r_val, fi.grid[1]) && _in_bounds(θ_val, fi.grid[2]) && _in_bounds(ϕ_val, fi.grid[3])) - res_val = fi.itp((r_val, θ_val, ϕ_val)) - if typeof(res_val) <: SVector || length(res_val) > 1 - return fill(eltype(res_val)(NaN), length(res_val)) + T_val = eltype(fi.itp) + if T_val <: SVector + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else - return typeof(res_val)(NaN) + return T_val(NaN) end end end @@ -177,6 +180,7 @@ end function (ci::ComponentInterpolator)(x) return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] end +Base.eltype(::ComponentInterpolator{T1, T2, T3}) where {T1, T2, T3} = SVector{3, eltype(T1)} function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} return range(T(first(g)), T(last(g)), length = length(g)) From 13ac8874e645fa382e6c5750b626ba16abf67c2f Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 13:56:57 -0500 Subject: [PATCH 09/27] Update FastInterpolations.jl to v0.4 --- Project.toml | 2 +- src/utility/fastinterpolation.jl | 225 +++++++++++++------------------ test/runtests.jl | 4 +- test/test_utility.jl | 18 +-- 4 files changed, 108 insertions(+), 141 deletions(-) diff --git a/Project.toml b/Project.toml index 4c535fdb7..2a1451034 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ Adapt = "4.4" ChunkSplitters = "3" DiffResults = "1" Distributed = "1" -FastInterpolations = "0.3.0" +FastInterpolations = "0.4" ForwardDiff = "1" KernelAbstractions = "0.9" LinearAlgebra = "1" diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 1f8429d25..f9b7a348f 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -4,14 +4,10 @@ using FastInterpolations using StaticArrays using Adapt -import FastInterpolations: _value_type -# Promote rule to construct the interpolant object natively for SVector. -_value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} - Base.eltype(::FastInterpolations.AbstractInterpolant{Tx, Ty}) where {Tx, Ty} = Ty Base.eltype(::FastInterpolations.AbstractInterpolantND{Tx, Ty, N}) where {Tx, Ty, N} = Ty -@inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) """ AbstractFieldInterpolator @@ -156,10 +152,6 @@ end Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) -function getinterp_scalar(A, grid1, grid2, grid3, args...) - return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) -end - function _get_extrap_mode(bc) if bc == 2 return Extrap(:wrap) @@ -171,17 +163,6 @@ function _get_extrap_mode(bc) end end -struct ComponentInterpolator{T1, T2, T3} - itp1::T1 - itp2::T2 - itp3::T3 -end - -function (ci::ComponentInterpolator)(x) - return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] -end -Base.eltype(::ComponentInterpolator{T1, T2, T3}) where {T1, T2, T3} = SVector{3, eltype(T1)} - function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} return range(T(first(g)), T(last(g)), length = length(g)) end @@ -204,56 +185,102 @@ function _fastinterp(grids, A, order, bc) elseif order == 2 return quadratic_interp(matched_grids, A; extrap = extrap_mode) elseif order == 3 - if eltype(A) <: SVector{3} - A1 = [v[1] for v in A] - A2 = [v[2] for v in A] - A3 = [v[3] for v in A] - itp1 = cubic_interp(matched_grids, A1; extrap = extrap_mode) - itp2 = cubic_interp(matched_grids, A2; extrap = extrap_mode) - itp3 = cubic_interp(matched_grids, A3; extrap = extrap_mode) - return ComponentInterpolator(itp1, itp2, itp3) - end return cubic_interp(matched_grids, A; extrap = extrap_mode) else return constant_interp(matched_grids, A; extrap = extrap_mode) end end -function getinterp( - ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) +""" + build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) + build_interpolator(A, grids..., order::Int=1, bc::Int=1) + +Return a function for interpolating field array `A` on the given grids. + +# Arguments + + - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. + - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. + - `order::Int=1`: order of interpolation in [1,2,3]. + - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Flat. + +# Notes +The input array `A` may be modified in-place for memory optimization. +""" +function build_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) +end + +function build_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." - else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" end - return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) end -function getinterp( - ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) +function build_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) +end + +function build_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." - else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" end - return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) + if order != 1 + throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) + end + + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) end -function getinterp( - ::Type{<:StructuredGrid}, A, gridr, gridθ, gridϕ, - order::Int = 1, bc::Int = 3 - ) +function build_interpolator( + ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) +end + +function build_interpolator( + ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + end + + is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) + + if is_uniform_r + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + gridϕ, A = _ensure_full_phi(gridϕ, A) + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) end - return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) + + return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) end -function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc::Int = 2) +function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." As = A @@ -266,27 +293,7 @@ function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc: return FieldInterpolator2D(itp, (gridx, gridy), bc) end -function get_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) -end - -function get_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - if order != 1 - throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) - end - - itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) -end - -function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = 3; dir = 1) +function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." As = A @@ -296,42 +303,8 @@ function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = end itp = _fastinterp((gridx,), As, order, bc) - return FieldInterpolator1D(itp, gridx, bc, dir) -end - -function getinterp_scalar( - ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) - return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) -end - -function getinterp_scalar( - ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) - return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) -end -function getinterp_scalar( - ::Type{<:StructuredGrid}, A, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 - ) - return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) -end - -function get_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) -end - -function get_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) + return FieldInterpolator1D(itp, gridx, bc, dir) end function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} @@ -375,32 +348,20 @@ function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} return new_grid_vec, new_A end -function get_interpolator( - ::Type{<:StructuredGrid}, - A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) -end - -function get_interpolator( - ::Type{<:StructuredGrid}, - A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 - ) where {T} - is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) +# Time-dependent field interpolation. - if is_uniform_r - itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - else - gridϕ, A = _ensure_full_phi(gridϕ, A) - itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - end +""" + LazyTimeInterpolator{T, F, L} - return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) -end +A callable struct for handling time-dependent fields with lazy loading and linear time interpolation. -# Time-dependent field interpolation. +# Fields +- `times::Vector{T}`: Sorted vector of time points. +- `loader::L`: Function `i -> field` that loads the field at index `i`. +- `buffer::Dict{Int, F}`: Cache for loaded fields. +- `lock::ReentrantLock`: Lock for thread safety. +""" struct LazyTimeInterpolator{T, F, L} <: Function times::Vector{T} loader::L @@ -409,6 +370,7 @@ struct LazyTimeInterpolator{T, F, L} <: Function end function LazyTimeInterpolator(times::AbstractVector, loader::Function) + # Determine the field type by loading the first field f1 = loader(1) return _LazyTimeInterpolator(times, loader, f1) end @@ -422,19 +384,22 @@ function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) w end function (itp::LazyTimeInterpolator)(x, t) + # Find the time interval [t1, t2] such that t1 <= t <= t2 (assume times is sorted) idx = searchsortedlast(itp.times, t) + # Handle out-of-bounds if idx == 0 - return _get_field!(itp, 1)(x) + return _get_field!(itp, 1)(x) # clamp to start elseif idx >= length(itp.times) - return _get_field!(itp, length(itp.times))(x) + return _get_field!(itp, length(itp.times))(x) # clamp to end end t1 = itp.times[idx] t2 = itp.times[idx + 1] - w = (t - t1) / (t2 - t1) + w = (t - t1) / (t2 - t1) # linear weights + # Load fields (lazily) f1 = _get_field!(itp, idx) f2 = _get_field!(itp, idx + 1) @@ -444,7 +409,9 @@ end function _get_field!(itp::LazyTimeInterpolator, idx::Int) return lock(itp.lock) do if !haskey(itp.buffer, idx) + # Remove far-away indices filter!(p -> abs(p.first - idx) <= 1, itp.buffer) + field = itp.loader(idx) itp.buffer[idx] = field end diff --git a/test/runtests.jl b/test/runtests.jl index 7d11c2f3e..65a68ecff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -170,7 +170,7 @@ end r, θ, ϕ, B_sph = setup_spherical_field() param = prepare(r, θ, ϕ, zero_E, B_sph; gridtype = StructuredGrid) # Check field interpolation Bz - @test param[4](SA[1.0, 1.0, 1.0])[3] == 9.888387888463716e-9 + @test param[4](SA[1.0, 1.0, 1.0])[3] ≈ 9.888387888463716e-9 # Test Vector inputs for spherical grid r_vec = collect(r) @@ -181,7 +181,7 @@ end r_vec, θ_vec, ϕ_vec, zero_E, B_sph; species = Ion(m = 16, q = 1), gridtype = TP.StructuredGrid ) - @test param_vec[4](SA[1.0, 1.0, 1.0])[3] == 9.888387888463716e-9 + @test param_vec[4](SA[1.0, 1.0, 1.0])[3] ≈ 9.888387888463716e-9 end @testset "analytical field" begin diff --git a/test/test_utility.jl b/test/test_utility.jl index 5c3bba260..27b000f81 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -50,23 +50,23 @@ import TestParticle as TP for i in eachindex(x), j in eachindex(y), k in eachindex(z) ] nfunc11 = TP.build_interpolator(n, x, y, z) - @test nfunc11(SA[9, 0, 0]) == 11.85 + @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 nfunc12 = TP.build_interpolator(n, x, y, z, 1, 2) - @test nfunc12(SA[20, 0, 0]) == 9.5 + @test nfunc12(SA[20, 0, 0]) ≈ 10.5f0 nfunc13 = TP.build_interpolator(n, x, y, z, 1, 3) - @test nfunc13(SA[20, 0, 0]) == 12.0 + @test nfunc13(SA[20, 0, 0]) ≈ 12.0f0 nfunc21 = TP.build_interpolator(n, x, y, z, 2) - @test nfunc21(SA[9, 0, 0]) == 11.898528302013874 + @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 nfunc22 = TP.build_interpolator(n, x, y, z, 2, 2) - @test nfunc22(SA[20, 0, 0]) == 9.166666686534882 + @test nfunc22(SA[20, 0, 0]) ≈ 10.5f0 nfunc23 = TP.build_interpolator(n, x, y, z, 2, 3) - @test nfunc23(SA[20, 0, 0]) == 12.14705765247345 + @test nfunc23(SA[20, 0, 0]) ≈ 12.0f0 nfunc31 = TP.build_interpolator(n, x, y, z, 3) - @test nfunc31(SA[9, 0, 0]) == 11.882999392215163 + @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 nfunc32 = TP.build_interpolator(n, x, y, z, 3, 2) - @test nfunc32(SA[20, 0, 0]) == 9.124999547351358 + @test nfunc32(SA[20, 0, 0]) ≈ 10.5f0 nfunc33 = TP.build_interpolator(n, x, y, z, 3, 3) - @test nfunc33(SA[20, 0, 0]) == 12.191176189381315 + @test nfunc33(SA[20, 0, 0]) ≈ 12.0f0 end begin # spherical interpolation From af413c5392098d95deeb6031eac86a2229a66415 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 16:41:58 -0500 Subject: [PATCH 10/27] Temporary version before FastInterpolations.jl supports constant extrapolation --- src/TestParticle.jl | 2 + src/utility/fastinterpolation.jl | 183 +++++++++++-------------------- src/utility/interpolation.jl | 28 +++-- test/test_utility.jl | 15 ++- 4 files changed, 94 insertions(+), 134 deletions(-) diff --git a/src/TestParticle.jl b/src/TestParticle.jl index 9ca9970ba..aa3ce2519 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -1,6 +1,8 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize +using FastInterpolations: linear_interp, quadratic_interp, cubic_interp, constant_interp, + Extrap, NoExtrap using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index f9b7a348f..4a2cc5d33 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -1,13 +1,4 @@ -# Field interpolations using FastInterpolations.jl. - -using FastInterpolations -using StaticArrays -using Adapt - -Base.eltype(::FastInterpolations.AbstractInterpolant{Tx, Ty}) where {Tx, Ty} = Ty -Base.eltype(::FastInterpolations.AbstractInterpolantND{Tx, Ty, N}) where {Tx, Ty, N} = Ty - -@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) +# Field interpolations. """ AbstractFieldInterpolator @@ -17,29 +8,17 @@ Abstract type for all field interpolators. abstract type AbstractFieldInterpolator <: Function end """ - FieldInterpolator{T, G} + FieldInterpolator{T} -A callable struct that wraps a 3D interpolation object and its grid. +A callable struct that wraps a 3D interpolation object. """ -struct FieldInterpolator{T, G} <: AbstractFieldInterpolator +struct FieldInterpolator{T} <: AbstractFieldInterpolator itp::T - grid::G - bc::Int end -_in_bounds(x, gridx) = first(gridx) <= x <= last(gridx) - -function (fi::FieldInterpolator)(xu) - if fi.bc == 1 - if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2]) && _in_bounds(xu[3], fi.grid[3])) - T_val = eltype(fi.itp) - if T_val <: SVector - return T_val(ntuple(_ -> NaN, Val(length(T_val)))) - else - return T_val(NaN) - end - end - end +const FieldInterpolator3D = FieldInterpolator + +@inbounds function (fi::FieldInterpolator)(xu) return fi.itp((xu[1], xu[2], xu[3])) end @@ -47,30 +26,19 @@ function (fi::FieldInterpolator)(xu, t) return fi(xu) end -Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) +Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp)) """ - FieldInterpolator2D{T, G} + FieldInterpolator2D{T} A callable struct that wraps a 2D interpolation object. """ -struct FieldInterpolator2D{T, G} <: AbstractFieldInterpolator +struct FieldInterpolator2D{T} <: AbstractFieldInterpolator itp::T - grid::G - bc::Int end -function (fi::FieldInterpolator2D)(xu) - if fi.bc == 1 - if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2])) - T_val = eltype(fi.itp) - if T_val <: SVector - return T_val(ntuple(_ -> NaN, Val(length(T_val)))) - else - return T_val(NaN) - end - end - end +@inbounds function (fi::FieldInterpolator2D)(xu) + # 2D interpolation usually involves x and y return fi.itp((xu[1], xu[2])) end @@ -78,31 +46,19 @@ function (fi::FieldInterpolator2D)(xu, t) return fi(xu) end -Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) +Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp)) """ - FieldInterpolator1D{T, G} + FieldInterpolator1D{T} A callable struct that wraps a 1D interpolation object. """ -struct FieldInterpolator1D{T, G} <: AbstractFieldInterpolator +struct FieldInterpolator1D{T} <: AbstractFieldInterpolator itp::T - grid::G - bc::Int dir::Int end -function (fi::FieldInterpolator1D)(xu) - if fi.bc == 1 - if !_in_bounds(xu[fi.dir], fi.grid) - T_val = eltype(fi.itp) - if T_val <: SVector - return T_val(ntuple(_ -> NaN, Val(length(T_val)))) - else - return T_val(NaN) - end - end - end +@inbounds function (fi::FieldInterpolator1D)(xu) return fi.itp((xu[fi.dir],)) end @@ -110,37 +66,24 @@ function (fi::FieldInterpolator1D)(xu, t) return fi(xu) end -Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc, fi.dir) +Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), fi.dir) """ - SphericalFieldInterpolator{T, G} + SphericalFieldInterpolator{T} -A callable struct for spherical grid interpolation. +A callable struct for spherical grid interpolation (scalar or combined vector). """ -struct SphericalFieldInterpolator{T, G} <: AbstractFieldInterpolator +struct SphericalFieldInterpolator{T} <: AbstractFieldInterpolator itp::T - grid::G - bc::Int end function (fi::SphericalFieldInterpolator)(xu) - r_val, θ_val, ϕ_val = cart2sph(xu) - - if fi.bc == 1 - if !(_in_bounds(r_val, fi.grid[1]) && _in_bounds(θ_val, fi.grid[2]) && _in_bounds(ϕ_val, fi.grid[3])) - T_val = eltype(fi.itp) - if T_val <: SVector - return T_val(ntuple(_ -> NaN, Val(length(T_val)))) - else - return T_val(NaN) - end - end - end - - res = fi.itp((r_val, θ_val, ϕ_val)) - if typeof(res) <: SVector || length(res) > 1 + rθϕ = cart2sph(xu) + res = fi.itp(rθϕ) + if length(res) > 1 + # Convert vector result from spherical to cartesian basis Br, Bθ, Bϕ = res - return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) + return sph_to_cart_vector(Br, Bθ, Bϕ, rθϕ[2], rθϕ[3]) else return res end @@ -150,7 +93,7 @@ function (fi::SphericalFieldInterpolator)(xu, t) return fi(xu) end -Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) +Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) function _get_extrap_mode(bc) if bc == 2 @@ -158,42 +101,32 @@ function _get_extrap_mode(bc) elseif bc == 3 return Extrap(:constant) else - # For bc == 1, we handle NaN in wrapper and use NoExtrap() for inner to avoid errors + # TODO: bc == 1 (NaN outside domain) requires native FastInterpolations.jl support. + # Once available, replace this with the appropriate Extrap mode and remove the + # manual bounds-checking code that was previously in the FieldInterpolator call methods. return NoExtrap() end end -function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} - return range(T(first(g)), T(last(g)), length = length(g)) -end -function _match_grid_type(g::AbstractVector, ::Type{T}) where {T <: AbstractFloat} - return T.(g) -end -function _match_grid_type(g::Tuple, ::Type{T}) where {T <: AbstractFloat} - return T.(g) -end function _fastinterp(grids, A, order, bc) - T_A = eltype(A) - T_F = T_A <: SVector ? eltype(T_A) : T_A - T_F = T_F <: AbstractFloat ? T_F : Float64 - matched_grids = map(g -> _match_grid_type(g, T_F), grids) - extrap_mode = _get_extrap_mode(bc) if order == 1 - return linear_interp(matched_grids, A; extrap = extrap_mode) + return linear_interp(grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(matched_grids, A; extrap = extrap_mode) + return quadratic_interp(grids, A; extrap = extrap_mode) elseif order == 3 - return cubic_interp(matched_grids, A; extrap = extrap_mode) + return cubic_interp(grids, A; extrap = extrap_mode) else - return constant_interp(matched_grids, A; extrap = extrap_mode) + return constant_interp(grids, A; extrap = extrap_mode) end end +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) + """ - build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) - build_interpolator(A, grids..., order::Int=1, bc::Int=1) + build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=3) + build_interpolator(A, grids..., order::Int=1, bc::Int=3) Return a function for interpolating field array `A` on the given grids. @@ -202,14 +135,15 @@ Return a function for interpolating field array `A` on the given grids. - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. - `order::Int=1`: order of interpolation in [1,2,3]. - - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Flat. + - `bc::Int=3`: type of boundary conditions, 1 -> NaN (not yet native; requires FastInterpolations support), 2 -> periodic, 3 -> Flat. # Notes The input array `A` may be modified in-place for memory optimization. """ function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -218,18 +152,21 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 3 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) + # Return field value at a given location. + return FieldInterpolator(itp) end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -238,7 +175,8 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 3 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -248,7 +186,7 @@ function build_interpolator( end itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) + return FieldInterpolator(itp) end function build_interpolator( @@ -267,20 +205,24 @@ function build_interpolator( if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end - + # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) + # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) if is_uniform_r itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - else + else # Non-uniform R (SphericalNonUniformR behavior) gridϕ, A = _ensure_full_phi(gridϕ, A) itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) end - return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) + return SphericalFieldInterpolator(itp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2) +function build_interpolator( + ::Type{<:CartesianGrid}, A, + gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 3 + ) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." As = A @@ -290,10 +232,13 @@ function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, g end itp = _fastinterp((gridx, gridy), As, order, bc) - return FieldInterpolator2D(itp, (gridx, gridy), bc) + return FieldInterpolator2D(itp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1) +function build_interpolator( + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + order::Int = 1, bc::Int = 3; dir = 1 + ) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." As = A @@ -304,7 +249,7 @@ function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, o itp = _fastinterp((gridx,), As, order, bc) - return FieldInterpolator1D(itp, gridx, bc, dir) + return FieldInterpolator1D(itp, dir) end function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} @@ -340,7 +285,7 @@ function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} if needs_2pi if needs_0 selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) - else + else # needs 2π only selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) end end diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 90dff4f66..956e70e88 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -1,7 +1,5 @@ # Field interpolations. -@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) - """ AbstractFieldInterpolator @@ -60,7 +58,7 @@ struct FieldInterpolator1D{T} <: AbstractFieldInterpolator dir::Int end -function (fi::FieldInterpolator1D)(xu) +@inbounds function (fi::FieldInterpolator1D)(xu) return fi.itp(xu[fi.dir]) end @@ -97,6 +95,8 @@ end Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) + """ build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) build_interpolator(A, grids..., order::Int=1, bc::Int=1) @@ -115,7 +115,8 @@ The input array `A` may be modified in-place for memory optimization. """ function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -124,7 +125,8 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -138,7 +140,8 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -147,7 +150,8 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -205,7 +209,10 @@ function build_interpolator( return SphericalFieldInterpolator(itp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2) +function build_interpolator( + ::Type{<:CartesianGrid}, A, + gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2 + ) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." As = A @@ -220,7 +227,10 @@ function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, g return FieldInterpolator2D(interp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1) +function build_interpolator( + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + order::Int = 1, bc::Int = 3; dir = 1 + ) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." As = A diff --git a/test/test_utility.jl b/test/test_utility.jl index 27b000f81..f7b1baad9 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -49,20 +49,23 @@ import TestParticle as TP Float32(i + j + k) for i in eachindex(x), j in eachindex(y), k in eachindex(z) ] - nfunc11 = TP.build_interpolator(n, x, y, z) - @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 + # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). + # nfunc11 = TP.build_interpolator(n, x, y, z) + # @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 nfunc12 = TP.build_interpolator(n, x, y, z, 1, 2) @test nfunc12(SA[20, 0, 0]) ≈ 10.5f0 nfunc13 = TP.build_interpolator(n, x, y, z, 1, 3) @test nfunc13(SA[20, 0, 0]) ≈ 12.0f0 - nfunc21 = TP.build_interpolator(n, x, y, z, 2) - @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 + # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). + # nfunc21 = TP.build_interpolator(n, x, y, z, 2) + # @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 nfunc22 = TP.build_interpolator(n, x, y, z, 2, 2) @test nfunc22(SA[20, 0, 0]) ≈ 10.5f0 nfunc23 = TP.build_interpolator(n, x, y, z, 2, 3) @test nfunc23(SA[20, 0, 0]) ≈ 12.0f0 - nfunc31 = TP.build_interpolator(n, x, y, z, 3) - @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 + # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). + # nfunc31 = TP.build_interpolator(n, x, y, z, 3) + # @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 nfunc32 = TP.build_interpolator(n, x, y, z, 3, 2) @test nfunc32(SA[20, 0, 0]) ≈ 10.5f0 nfunc33 = TP.build_interpolator(n, x, y, z, 3, 3) From bdd6afc2d8694deec00b2e576f8d7dba0c451b2c Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 21:30:52 -0500 Subject: [PATCH 11/27] Simplify spherical grid support --- src/TestParticle.jl | 2 +- src/utility/fastinterpolation.jl | 78 ++++++++++---------------------- test/test_utility.jl | 2 +- 3 files changed, 27 insertions(+), 55 deletions(-) diff --git a/src/TestParticle.jl b/src/TestParticle.jl index aa3ce2519..1d42727c5 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -2,7 +2,7 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize using FastInterpolations: linear_interp, quadratic_interp, cubic_interp, constant_interp, - Extrap, NoExtrap + Extrap, NoExtrap, PeriodicBC, ZeroCurvBC using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 4a2cc5d33..2f31f18d9 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -109,16 +109,20 @@ function _get_extrap_mode(bc) end -function _fastinterp(grids, A, order, bc) +function _fastinterp(grids, A, order, bc; spline_bc = nothing) extrap_mode = _get_extrap_mode(bc) if order == 1 return linear_interp(grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(grids, A; extrap = extrap_mode) + kwargs = spline_bc !== nothing ? (; extrap = extrap_mode, bc = spline_bc) : (; extrap = extrap_mode) + return quadratic_interp(grids, A; kwargs...) elseif order == 3 - return cubic_interp(grids, A; extrap = extrap_mode) - else + kwargs = spline_bc !== nothing ? (; extrap = extrap_mode, bc = spline_bc) : (; extrap = extrap_mode) + return cubic_interp(grids, A; kwargs...) + elseif order == 0 return constant_interp(grids, A; extrap = extrap_mode) + else + throw(ArgumentError("Unsupported interpolation order! Expected order in [1, 2, 3].")) end end @@ -205,16 +209,24 @@ function build_interpolator( if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end - # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) - # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! - is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) - - if is_uniform_r - itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - else # Non-uniform R (SphericalNonUniformR behavior) - gridϕ, A = _ensure_full_phi(gridϕ, A) - itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + r_min, r_max = extrema(gridr) + θ_min, θ_max = extrema(gridθ) + ϕ_min, ϕ_max = extrema(gridϕ) + + @assert r_min > 0 "r must be strictly positive." + @assert θ_min >= 0 && θ_max <= π "θ must be within [0, π]." + @assert ϕ_min >= 0 && ϕ_max <= 2π "ϕ must be within [0, 2π]." + + has_0 = isapprox(ϕ_min, 0, atol = 1.0e-5) + has_2pi = isapprox(ϕ_max, 2π, atol = 1.0e-5) + if has_0 && has_2pi + phi_bc = PeriodicBC(endpoint = :inclusive) + else + phi_bc = PeriodicBC(endpoint = :exclusive, period = 2π) end + #TODO switch to constant extrapolation later + spline_bc = (ZeroCurvBC(), ZeroCurvBC(), phi_bc) + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc; spline_bc) return SphericalFieldInterpolator(itp) end @@ -252,46 +264,6 @@ function build_interpolator( return FieldInterpolator1D(itp, dir) end -function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} - min_phi, max_phi = extrema(gridϕ) - needs_0 = !isapprox(min_phi, 0, atol = 1.0e-5) - needs_2pi = !isapprox(max_phi, 2π, atol = 1.0e-5) - - if !needs_0 && !needs_2pi - return gridϕ, A - end - - new_grid_vec = collect(gridϕ) - if needs_0 - pushfirst!(new_grid_vec, 0.0) - end - if needs_2pi - push!(new_grid_vec, 2π) - end - - phi_dim = N - new_A = Array{T, N}(undef, (size(A)[1:(end - 1)]..., length(new_grid_vec))) - - start_idx = needs_0 ? 2 : 1 - end_idx = start_idx + length(gridϕ) - 1 - - selectdim(new_A, phi_dim, start_idx:end_idx) .= A - - if needs_0 - src_idx_for_0 = size(A, phi_dim) - selectdim(new_A, phi_dim, 1) .= selectdim(A, phi_dim, src_idx_for_0) - end - - if needs_2pi - if needs_0 - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) - else # needs 2π only - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) - end - end - - return new_grid_vec, new_A -end # Time-dependent field interpolation. diff --git a/test/test_utility.jl b/test/test_utility.jl index f7b1baad9..10dfde323 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -73,7 +73,7 @@ import TestParticle as TP end begin # spherical interpolation - r = range(0, 10, length = 11) + r = range(1.0, 10.0, length = 11) θ = range(0, π, length = 11) ϕ = range(0, 2π, length = 11) # Vector field From b60523ff213f977a4c59c0c42d570b2d1c3ca227 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 21:39:00 -0500 Subject: [PATCH 12/27] Update interpolation.jl --- src/utility/interpolation.jl | 69 ++++++++++-------------------------- 1 file changed, 18 insertions(+), 51 deletions(-) diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 956e70e88..dba80e405 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -193,19 +193,26 @@ function build_interpolator( if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end - # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) - # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! - is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) - - if is_uniform_r - itp_unscaled = _get_interp_object(StructuredGrid, A, order, bc) - itp = scale(itp_unscaled, gridr, gridθ, gridϕ) - else # Non-uniform R (SphericalNonUniformR behavior) - bctype = (Flat(), Flat(), Periodic()) - gridϕ, A = _ensure_full_phi(gridϕ, A) - itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) + r_min, r_max = extrema(gridr) + θ_min, θ_max = extrema(gridθ) + ϕ_min, ϕ_max = extrema(gridϕ) + + @assert r_min > 0 "r must be strictly positive." + @assert θ_min >= 0 && θ_max <= π "θ must be within [0, π]." + @assert ϕ_min >= 0 && ϕ_max <= 2π "ϕ must be within [0, 2π]." + + has_0 = isapprox(ϕ_min, 0, atol = 1.0e-5) + has_2pi = isapprox(ϕ_max, 2π, atol = 1.0e-5) + + if has_0 && has_2pi + phi_bc = Periodic(OnGrid()) + else + phi_bc = Periodic(OnCell()) end + bctype = (Flat(), Flat(), phi_bc) + itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) + return SphericalFieldInterpolator(itp) end @@ -306,46 +313,6 @@ function _get_interp_object(::Type{<:StructuredGrid}, A, order::Int, bc::Int) return extrapolate(interpolate(A, itp_type), bctype) end -function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} - min_phi, max_phi = extrema(gridϕ) - needs_0 = !isapprox(min_phi, 0, atol = 1.0e-5) - needs_2pi = !isapprox(max_phi, 2π, atol = 1.0e-5) - - if !needs_0 && !needs_2pi - return gridϕ, A - end - - new_grid_vec = collect(gridϕ) - if needs_0 - pushfirst!(new_grid_vec, 0.0) - end - if needs_2pi - push!(new_grid_vec, 2π) - end - - phi_dim = N - new_A = Array{T, N}(undef, (size(A)[1:(end - 1)]..., length(new_grid_vec))) - - start_idx = needs_0 ? 2 : 1 - end_idx = start_idx + length(gridϕ) - 1 - - selectdim(new_A, phi_dim, start_idx:end_idx) .= A - - if needs_0 - src_idx_for_0 = size(A, phi_dim) - selectdim(new_A, phi_dim, 1) .= selectdim(A, phi_dim, src_idx_for_0) - end - - if needs_2pi - if needs_0 - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) - else # needs 2π only - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) - end - end - - return new_grid_vec, new_A -end # Time-dependent field interpolation. From 980414b41f063b67156225ad48ba42da2d44bf77 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Mar 2026 00:03:43 -0500 Subject: [PATCH 13/27] refactor: improve spherical interpolation (#485) * refactor: improve boundary handling * Support higher order interpolation for StructuredGrid --- src/TestParticle.jl | 2 +- src/utility/interpolation.jl | 133 +++++++++++++++-------------------- test/test_utility.jl | 13 +++- 3 files changed, 69 insertions(+), 79 deletions(-) diff --git a/src/TestParticle.jl b/src/TestParticle.jl index f5c1bed22..3b37f37fe 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -3,7 +3,7 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize using Interpolations: interpolate, interpolate!, extrapolate, scale, BSpline, Linear, Quadratic, Cubic, - Line, OnCell, Periodic, Flat, Gridded + Line, OnCell, OnGrid, Periodic, Flat, Gridded using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 90dff4f66..cbec3f8e3 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -1,7 +1,5 @@ # Field interpolations. -@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) - """ AbstractFieldInterpolator @@ -60,7 +58,7 @@ struct FieldInterpolator1D{T} <: AbstractFieldInterpolator dir::Int end -function (fi::FieldInterpolator1D)(xu) +@inbounds function (fi::FieldInterpolator1D)(xu) return fi.itp(xu[fi.dir]) end @@ -97,6 +95,8 @@ end Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) + """ build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) build_interpolator(A, grids..., order::Int=1, bc::Int=1) @@ -115,7 +115,8 @@ The input array `A` may be modified in-place for memory optimization. """ function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -124,7 +125,8 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -138,7 +140,8 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -147,7 +150,8 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -175,7 +179,7 @@ end function build_interpolator( ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -184,28 +188,56 @@ end function build_interpolator( ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end - # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) - # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! - is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) - - if is_uniform_r - itp_unscaled = _get_interp_object(StructuredGrid, A, order, bc) - itp = scale(itp_unscaled, gridr, gridθ, gridϕ) - else # Non-uniform R (SphericalNonUniformR behavior) - bctype = (Flat(), Flat(), Periodic()) - gridϕ, A = _ensure_full_phi(gridϕ, A) + r_min, r_max = extrema(gridr) + θ_min, θ_max = extrema(gridθ) + ϕ_min, ϕ_max = extrema(gridϕ) + + @assert r_min >= 0 "r must be non-negative." + @assert θ_min >= 0 && θ_max <= π "θ must be within [0, π]." + @assert ϕ_min >= 0 && ϕ_max <= 2π "ϕ must be within [0, 2π]." + + has_0 = isapprox(ϕ_min, 0, atol = 1.0e-5) + has_2pi = isapprox(ϕ_max, 2π, atol = 1.0e-5) + + ϕ_bc = if has_0 && has_2pi + Periodic(OnGrid()) + else + Periodic(OnCell()) + end + + bctype = (Flat(), Flat(), ϕ_bc) + + if order == 1 itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) + else + interp_type = if order == 2 + Quadratic + elseif order == 3 + Cubic + else + throw(ArgumentError("Unsupported interpolation order!")) + end + itp_type = ( + BSpline(interp_type(Flat(OnCell()))), + BSpline(interp_type(Flat(OnCell()))), + BSpline(interp_type(ϕ_bc)), + ) + itp_obj = extrapolate(interpolate(A, itp_type), bctype) + itp = scale(itp_obj, gridr, gridθ, gridϕ) end return SphericalFieldInterpolator(itp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2) +function build_interpolator( + ::Type{<:CartesianGrid}, A, + gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2 + ) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." As = A @@ -220,7 +252,10 @@ function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, g return FieldInterpolator2D(interp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1) +function build_interpolator( + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + order::Int = 1, bc::Int = 3; dir = 1 + ) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." As = A @@ -280,62 +315,6 @@ function _get_interp_object(A, order::Int, bc::Int) return extrapolate(interpolate(A, bspline), bctype) end -function _get_interp_object(::Type{<:StructuredGrid}, A, order::Int, bc::Int) - bspline_r = _get_bspline(order, false) - bspline_θ = _get_bspline(order, false) - bspline_ϕ = _get_bspline(order, true) - - itp_type = (bspline_r, bspline_θ, bspline_ϕ) - - bctype = if eltype(A) <: SVector - SVector{3, eltype(eltype(A))}(NaN, NaN, NaN) - else - eltype(A)(NaN) - end - - return extrapolate(interpolate(A, itp_type), bctype) -end - -function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} - min_phi, max_phi = extrema(gridϕ) - needs_0 = !isapprox(min_phi, 0, atol = 1.0e-5) - needs_2pi = !isapprox(max_phi, 2π, atol = 1.0e-5) - - if !needs_0 && !needs_2pi - return gridϕ, A - end - - new_grid_vec = collect(gridϕ) - if needs_0 - pushfirst!(new_grid_vec, 0.0) - end - if needs_2pi - push!(new_grid_vec, 2π) - end - - phi_dim = N - new_A = Array{T, N}(undef, (size(A)[1:(end - 1)]..., length(new_grid_vec))) - - start_idx = needs_0 ? 2 : 1 - end_idx = start_idx + length(gridϕ) - 1 - - selectdim(new_A, phi_dim, start_idx:end_idx) .= A - - if needs_0 - src_idx_for_0 = size(A, phi_dim) - selectdim(new_A, phi_dim, 1) .= selectdim(A, phi_dim, src_idx_for_0) - end - - if needs_2pi - if needs_0 - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) - else # needs 2π only - selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) - end - end - - return new_grid_vec, new_A -end # Time-dependent field interpolation. diff --git a/test/test_utility.jl b/test/test_utility.jl index 5c3bba260..c45a68002 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -70,7 +70,7 @@ import TestParticle as TP end begin # spherical interpolation - r = range(0, 10, length = 11) + r = range(0.1, 10, length = 11) θ = range(0, π, length = 11) ϕ = range(0, 2π, length = 11) # Vector field @@ -83,6 +83,17 @@ import TestParticle as TP Afunc = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ) @test Afunc(SA[1, 1, 1]) == 1.0 @test Afunc(SA[0, 0, 0]) == 1.0 + + # High order spherical interpolation + Bfunc2 = TP.build_interpolator(TP.StructuredGrid, B, r, θ, ϕ, 2) + @test Bfunc2(SA[1, 1, 1]) ≈ [0.57735, 0.57735, 0.57735] atol = 1.0e-5 + Afunc2 = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ, 2) + @test Afunc2(SA[1, 1, 1]) ≈ 1.0 + + Bfunc3 = TP.build_interpolator(TP.StructuredGrid, B, r, θ, ϕ, 3) + @test Bfunc3(SA[1, 1, 1]) ≈ [0.57735, 0.57735, 0.57735] atol = 1.0e-5 + Afunc3 = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ, 3) + @test Afunc3(SA[1, 1, 1]) ≈ 1.0 end begin # non-uniform spherical interpolation From 33e74f46198c68a302b95b84e4196bf09d0d5f9f Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 10:24:15 -0500 Subject: [PATCH 14/27] refactor: substitute Interpolations.jl with FastInterpolations.jl --- Project.toml | 4 +- src/TestParticle.jl | 5 +- src/utility/fastinterpolation.jl | 434 +++++++++++++++++++++++++++++++ 3 files changed, 437 insertions(+), 6 deletions(-) create mode 100644 src/utility/fastinterpolation.jl diff --git a/Project.toml b/Project.toml index 2b7d7b7f9..4c535fdb7 100644 --- a/Project.toml +++ b/Project.toml @@ -11,8 +11,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" @@ -38,8 +38,8 @@ Adapt = "4.4" ChunkSplitters = "3" DiffResults = "1" Distributed = "1" +FastInterpolations = "0.3.0" ForwardDiff = "1" -Interpolations = "0.15, 0.16" KernelAbstractions = "0.9" LinearAlgebra = "1" Meshes = "0.55, 0.56" diff --git a/src/TestParticle.jl b/src/TestParticle.jl index 3b37f37fe..9ca9970ba 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -1,9 +1,6 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize -using Interpolations: interpolate, interpolate!, extrapolate, scale, BSpline, Linear, - Quadratic, Cubic, - Line, OnCell, OnGrid, Periodic, Flat, Gridded using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, @@ -46,7 +43,7 @@ export EnsembleSerial, EnsembleThreads, EnsembleDistributed, remake include("types.jl") include("utility/utility.jl") -include("utility/interpolation.jl") +include("utility/fastinterpolation.jl") include("sampler.jl") include("prepare.jl") include("gc.jl") diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl new file mode 100644 index 000000000..5f642a046 --- /dev/null +++ b/src/utility/fastinterpolation.jl @@ -0,0 +1,434 @@ +# Field interpolations using FastInterpolations.jl. + +using FastInterpolations +using StaticArrays +using Adapt + +import FastInterpolations: _value_type +# Promote rule to construct the interpolant object natively for SVector. +_value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} + +@inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) + +""" + AbstractFieldInterpolator + +Abstract type for all field interpolators. +""" +abstract type AbstractFieldInterpolator <: Function end + +""" + FieldInterpolator{T, G} + +A callable struct that wraps a 3D interpolation object and its grid. +""" +struct FieldInterpolator{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +_in_bounds(x, gridx) = first(gridx) <= x <= last(gridx) + +function (fi::FieldInterpolator)(xu) + if fi.bc == 1 + if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2]) && _in_bounds(xu[3], fi.grid[3])) + T_val = typeof(fi.itp((xu[1], xu[2], xu[3]))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[1], xu[2], xu[3])) +end + +function (fi::FieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +""" + FieldInterpolator2D{T, G} + +A callable struct that wraps a 2D interpolation object. +""" +struct FieldInterpolator2D{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +function (fi::FieldInterpolator2D)(xu) + if fi.bc == 1 + if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2])) + T_val = typeof(fi.itp((xu[1], xu[2]))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[1], xu[2])) +end + +function (fi::FieldInterpolator2D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +""" + FieldInterpolator1D{T, G} + +A callable struct that wraps a 1D interpolation object. +""" +struct FieldInterpolator1D{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int + dir::Int +end + +function (fi::FieldInterpolator1D)(xu) + if fi.bc == 1 + if !_in_bounds(xu[fi.dir], fi.grid) + T_val = typeof(fi.itp((xu[fi.dir],))) + if T_val <: SVector + return fill(eltype(T_val)(NaN), length(T_val)) + else + return T_val(NaN) + end + end + end + return fi.itp((xu[fi.dir],)) +end + +function (fi::FieldInterpolator1D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc, fi.dir) + +""" + SphericalFieldInterpolator{T, G} + +A callable struct for spherical grid interpolation. +""" +struct SphericalFieldInterpolator{T, G} <: AbstractFieldInterpolator + itp::T + grid::G + bc::Int +end + +function (fi::SphericalFieldInterpolator)(xu) + r_val, θ_val, ϕ_val = cart2sph(xu) + + if fi.bc == 1 + if !(_in_bounds(r_val, fi.grid[1]) && _in_bounds(θ_val, fi.grid[2]) && _in_bounds(ϕ_val, fi.grid[3])) + res_val = fi.itp((r_val, θ_val, ϕ_val)) + if typeof(res_val) <: SVector || length(res_val) > 1 + return fill(eltype(res_val)(NaN), length(res_val)) + else + return typeof(res_val)(NaN) + end + end + end + + res = fi.itp((r_val, θ_val, ϕ_val)) + if typeof(res) <: SVector || length(res) > 1 + Br, Bθ, Bϕ = res + return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) + else + return res + end +end + +function (fi::SphericalFieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) + +function getinterp_scalar(A, grid1, grid2, grid3, args...) + return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) +end + +function _get_extrap_mode(bc) + if bc == 2 + return Extrap(:wrap) + elseif bc == 3 + return Extrap(:constant) + else + # For bc == 1, we handle NaN in wrapper and use NoExtrap() for inner to avoid errors + return NoExtrap() + end +end + +struct ComponentInterpolator{T1, T2, T3} + itp1::T1 + itp2::T2 + itp3::T3 +end + +function (ci::ComponentInterpolator)(x) + return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] +end + +function _fastinterp(grids, A, order, bc) + extrap_mode = _get_extrap_mode(bc) + if order == 1 + return linear_interp(grids, A; extrap = extrap_mode) + elseif order == 2 + return quadratic_interp(grids, A; extrap = extrap_mode) + elseif order == 3 + if eltype(A) <: SVector{3} + A1 = [v[1] for v in A] + A2 = [v[2] for v in A] + A3 = [v[3] for v in A] + itp1 = cubic_interp(grids, A1; extrap = extrap_mode) + itp2 = cubic_interp(grids, A2; extrap = extrap_mode) + itp3 = cubic_interp(grids, A3; extrap = extrap_mode) + return ComponentInterpolator(itp1, itp2, itp3) + end + return cubic_interp(grids, A; extrap = extrap_mode) + else + return constant_interp(grids, A; extrap = extrap_mode) + end +end + +function getinterp( + ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp( + ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp( + ::Type{<:StructuredGrid}, A, gridr, gridθ, gridϕ, + order::Int = 1, bc::Int = 3 + ) + if eltype(A) <: SVector + @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + else + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + end + return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) +end + +function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc::Int = 2) + if eltype(A) <: SVector + @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." + As = A + else + @assert size(A, 1) == 3 && ndims(A) == 3 "Inconsistent 2D force field and grid!" + As = reinterpret(reshape, SVector{3, eltype(A)}, A) + end + + itp = _fastinterp((gridx, gridy), As, order, bc) + return FieldInterpolator2D(itp, (gridx, gridy), bc) +end + +function get_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) +end + +function get_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + if order != 1 + throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) + end + + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) +end + +function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = 3; dir = 1) + if eltype(A) <: SVector + @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." + As = A + else + @assert size(A, 1) == 3 && ndims(A) == 2 "Inconsistent 1D force field and grid!" + As = reinterpret(reshape, SVector{3, eltype(A)}, A) + end + + itp = _fastinterp((gridx,), As, order, bc) + return FieldInterpolator1D(itp, gridx, bc, dir) +end + +function getinterp_scalar( + ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp_scalar( + ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) + return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) +end + +function getinterp_scalar( + ::Type{<:StructuredGrid}, A, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) + return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) +end + +function get_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) +end + +function get_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + gridx, gridy, gridz, order::Int = 1, bc::Int = 1 + ) where {T} + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) +end + +function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} + min_phi, max_phi = extrema(gridϕ) + needs_0 = !isapprox(min_phi, 0, atol = 1.0e-5) + needs_2pi = !isapprox(max_phi, 2π, atol = 1.0e-5) + + if !needs_0 && !needs_2pi + return gridϕ, A + end + + new_grid_vec = collect(gridϕ) + if needs_0 + pushfirst!(new_grid_vec, 0.0) + end + if needs_2pi + push!(new_grid_vec, 2π) + end + + phi_dim = N + new_A = Array{T, N}(undef, (size(A)[1:(end - 1)]..., length(new_grid_vec))) + + start_idx = needs_0 ? 2 : 1 + end_idx = start_idx + length(gridϕ) - 1 + + selectdim(new_A, phi_dim, start_idx:end_idx) .= A + + if needs_0 + src_idx_for_0 = size(A, phi_dim) + selectdim(new_A, phi_dim, 1) .= selectdim(A, phi_dim, src_idx_for_0) + end + + if needs_2pi + if needs_0 + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) + else + selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) + end + end + + return new_grid_vec, new_A +end + +function get_interpolator( + ::Type{<:StructuredGrid}, + A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 + ) where {T} + As = reinterpret(reshape, SVector{3, T}, A) + return get_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) +end + +function get_interpolator( + ::Type{<:StructuredGrid}, + A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 + ) where {T} + is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) + + if is_uniform_r + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + else + gridϕ, A = _ensure_full_phi(gridϕ, A) + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) + end + + return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) +end + +# Time-dependent field interpolation. + +struct LazyTimeInterpolator{T, F, L} <: Function + times::Vector{T} + loader::L + buffer::Dict{Int, F} + lock::ReentrantLock +end + +function LazyTimeInterpolator(times::AbstractVector, loader::Function) + f1 = loader(1) + return _LazyTimeInterpolator(times, loader, f1) +end + +function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) where {F} + buffer = Dict{Int, F}(1 => f1) + lock = ReentrantLock() + return LazyTimeInterpolator{eltype(times), F, typeof(loader)}( + times, loader, buffer, lock + ) +end + +function (itp::LazyTimeInterpolator)(x, t) + idx = searchsortedlast(itp.times, t) + + if idx == 0 + return _get_field!(itp, 1)(x) + elseif idx >= length(itp.times) + return _get_field!(itp, length(itp.times))(x) + end + + t1 = itp.times[idx] + t2 = itp.times[idx + 1] + + w = (t - t1) / (t2 - t1) + + f1 = _get_field!(itp, idx) + f2 = _get_field!(itp, idx + 1) + + return (1 - w) * f1(x) + w * f2(x) +end + +function _get_field!(itp::LazyTimeInterpolator, idx::Int) + return lock(itp.lock) do + if !haskey(itp.buffer, idx) + filter!(p -> abs(p.first - idx) <= 1, itp.buffer) + field = itp.loader(idx) + itp.buffer[idx] = field + end + return itp.buffer[idx] + end +end From b2d3c0bf4cc1692385f869f3ca1bc4894df0527a Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 22:50:15 -0500 Subject: [PATCH 15/27] Try to fix the Float32 vs Float64 bug --- src/utility/fastinterpolation.jl | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 5f642a046..a300e1bd3 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -178,25 +178,40 @@ function (ci::ComponentInterpolator)(x) return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] end +function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} + return range(T(first(g)), T(last(g)), length = length(g)) +end +function _match_grid_type(g::AbstractVector, ::Type{T}) where {T <: AbstractFloat} + return T.(g) +end +function _match_grid_type(g::Tuple, ::Type{T}) where {T <: AbstractFloat} + return T.(g) +end + function _fastinterp(grids, A, order, bc) + T_A = eltype(A) + T_F = T_A <: SVector ? eltype(T_A) : T_A + T_F = T_F <: AbstractFloat ? T_F : Float64 + matched_grids = map(g -> _match_grid_type(g, T_F), grids) + extrap_mode = _get_extrap_mode(bc) if order == 1 - return linear_interp(grids, A; extrap = extrap_mode) + return linear_interp(matched_grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(grids, A; extrap = extrap_mode) + return quadratic_interp(matched_grids, A; extrap = extrap_mode) elseif order == 3 if eltype(A) <: SVector{3} A1 = [v[1] for v in A] A2 = [v[2] for v in A] A3 = [v[3] for v in A] - itp1 = cubic_interp(grids, A1; extrap = extrap_mode) - itp2 = cubic_interp(grids, A2; extrap = extrap_mode) - itp3 = cubic_interp(grids, A3; extrap = extrap_mode) + itp1 = cubic_interp(matched_grids, A1; extrap = extrap_mode) + itp2 = cubic_interp(matched_grids, A2; extrap = extrap_mode) + itp3 = cubic_interp(matched_grids, A3; extrap = extrap_mode) return ComponentInterpolator(itp1, itp2, itp3) end - return cubic_interp(grids, A; extrap = extrap_mode) + return cubic_interp(matched_grids, A; extrap = extrap_mode) else - return constant_interp(grids, A; extrap = extrap_mode) + return constant_interp(matched_grids, A; extrap = extrap_mode) end end From 53ae9dce71a294c16adb28b5622ec0bd102d9808 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 23:12:19 -0500 Subject: [PATCH 16/27] fix out-of-bound return type --- src/utility/fastinterpolation.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index a300e1bd3..1f8429d25 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -8,6 +8,9 @@ import FastInterpolations: _value_type # Promote rule to construct the interpolant object natively for SVector. _value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} +Base.eltype(::FastInterpolations.AbstractInterpolant{Tx, Ty}) where {Tx, Ty} = Ty +Base.eltype(::FastInterpolations.AbstractInterpolantND{Tx, Ty, N}) where {Tx, Ty, N} = Ty + @inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) """ @@ -33,9 +36,9 @@ _in_bounds(x, gridx) = first(gridx) <= x <= last(gridx) function (fi::FieldInterpolator)(xu) if fi.bc == 1 if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2]) && _in_bounds(xu[3], fi.grid[3])) - T_val = typeof(fi.itp((xu[1], xu[2], xu[3]))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -64,9 +67,9 @@ end function (fi::FieldInterpolator2D)(xu) if fi.bc == 1 if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2])) - T_val = typeof(fi.itp((xu[1], xu[2]))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -96,9 +99,9 @@ end function (fi::FieldInterpolator1D)(xu) if fi.bc == 1 if !_in_bounds(xu[fi.dir], fi.grid) - T_val = typeof(fi.itp((xu[fi.dir],))) + T_val = eltype(fi.itp) if T_val <: SVector - return fill(eltype(T_val)(NaN), length(T_val)) + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else return T_val(NaN) end @@ -129,11 +132,11 @@ function (fi::SphericalFieldInterpolator)(xu) if fi.bc == 1 if !(_in_bounds(r_val, fi.grid[1]) && _in_bounds(θ_val, fi.grid[2]) && _in_bounds(ϕ_val, fi.grid[3])) - res_val = fi.itp((r_val, θ_val, ϕ_val)) - if typeof(res_val) <: SVector || length(res_val) > 1 - return fill(eltype(res_val)(NaN), length(res_val)) + T_val = eltype(fi.itp) + if T_val <: SVector + return T_val(ntuple(_ -> NaN, Val(length(T_val)))) else - return typeof(res_val)(NaN) + return T_val(NaN) end end end @@ -177,6 +180,7 @@ end function (ci::ComponentInterpolator)(x) return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] end +Base.eltype(::ComponentInterpolator{T1, T2, T3}) where {T1, T2, T3} = SVector{3, eltype(T1)} function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} return range(T(first(g)), T(last(g)), length = length(g)) From 0d8e6cdf79a15a123d6eb3cf0e1402584b15ce43 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 13:56:57 -0500 Subject: [PATCH 17/27] Update FastInterpolations.jl to v0.4 --- Project.toml | 2 +- src/utility/fastinterpolation.jl | 225 +++++++++++++------------------ test/runtests.jl | 4 +- test/test_utility.jl | 18 +-- 4 files changed, 108 insertions(+), 141 deletions(-) diff --git a/Project.toml b/Project.toml index 4c535fdb7..2a1451034 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ Adapt = "4.4" ChunkSplitters = "3" DiffResults = "1" Distributed = "1" -FastInterpolations = "0.3.0" +FastInterpolations = "0.4" ForwardDiff = "1" KernelAbstractions = "0.9" LinearAlgebra = "1" diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 1f8429d25..f9b7a348f 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -4,14 +4,10 @@ using FastInterpolations using StaticArrays using Adapt -import FastInterpolations: _value_type -# Promote rule to construct the interpolant object natively for SVector. -_value_type(::Type{SVector{N, T}}, ::Type{Tg}) where {N, T, Tg <: AbstractFloat} = SVector{N, promote_type(T, Tg)} - Base.eltype(::FastInterpolations.AbstractInterpolant{Tx, Ty}) where {Tx, Ty} = Ty Base.eltype(::FastInterpolations.AbstractInterpolantND{Tx, Ty, N}) where {Tx, Ty, N} = Ty -@inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) """ AbstractFieldInterpolator @@ -156,10 +152,6 @@ end Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) -function getinterp_scalar(A, grid1, grid2, grid3, args...) - return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) -end - function _get_extrap_mode(bc) if bc == 2 return Extrap(:wrap) @@ -171,17 +163,6 @@ function _get_extrap_mode(bc) end end -struct ComponentInterpolator{T1, T2, T3} - itp1::T1 - itp2::T2 - itp3::T3 -end - -function (ci::ComponentInterpolator)(x) - return SA[ci.itp1(x), ci.itp2(x), ci.itp3(x)] -end -Base.eltype(::ComponentInterpolator{T1, T2, T3}) where {T1, T2, T3} = SVector{3, eltype(T1)} - function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} return range(T(first(g)), T(last(g)), length = length(g)) end @@ -204,56 +185,102 @@ function _fastinterp(grids, A, order, bc) elseif order == 2 return quadratic_interp(matched_grids, A; extrap = extrap_mode) elseif order == 3 - if eltype(A) <: SVector{3} - A1 = [v[1] for v in A] - A2 = [v[2] for v in A] - A3 = [v[3] for v in A] - itp1 = cubic_interp(matched_grids, A1; extrap = extrap_mode) - itp2 = cubic_interp(matched_grids, A2; extrap = extrap_mode) - itp3 = cubic_interp(matched_grids, A3; extrap = extrap_mode) - return ComponentInterpolator(itp1, itp2, itp3) - end return cubic_interp(matched_grids, A; extrap = extrap_mode) else return constant_interp(matched_grids, A; extrap = extrap_mode) end end -function getinterp( - ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) +""" + build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) + build_interpolator(A, grids..., order::Int=1, bc::Int=1) + +Return a function for interpolating field array `A` on the given grids. + +# Arguments + + - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. + - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. + - `order::Int=1`: order of interpolation in [1,2,3]. + - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Flat. + +# Notes +The input array `A` may be modified in-place for memory optimization. +""" +function build_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) +end + +function build_interpolator( + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." - else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" end - return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) end -function getinterp( - ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) +function build_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) +end + +function build_interpolator( + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." - else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" end - return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) + if order != 1 + throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) + end + + itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + return FieldInterpolator(itp, (gridx, gridy, gridz), bc) end -function getinterp( - ::Type{<:StructuredGrid}, A, gridr, gridθ, gridϕ, - order::Int = 1, bc::Int = 3 - ) +function build_interpolator( + ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) where {T} + @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + As = reinterpret(reshape, SVector{3, T}, A) + return build_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) +end + +function build_interpolator( + ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." + end + + is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) + + if is_uniform_r + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) else - @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" + gridϕ, A = _ensure_full_phi(gridϕ, A) + itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) end - return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) + + return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) end -function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc::Int = 2) +function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." As = A @@ -266,27 +293,7 @@ function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc: return FieldInterpolator2D(itp, (gridx, gridy), bc) end -function get_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) -end - -function get_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - if order != 1 - throw(ArgumentError("RectilinearGrid (CartesianNonUniform) only supports order=1 (Linear) interpolation.")) - end - - itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) -end - -function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = 3; dir = 1) +function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." As = A @@ -296,42 +303,8 @@ function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = end itp = _fastinterp((gridx,), As, order, bc) - return FieldInterpolator1D(itp, gridx, bc, dir) -end - -function getinterp_scalar( - ::Type{<:CartesianGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) - return get_interpolator(CartesianGrid, A, gridx, gridy, gridz, order, bc) -end - -function getinterp_scalar( - ::Type{<:RectilinearGrid}, A, gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) - return get_interpolator(RectilinearGrid, A, gridx, gridy, gridz, order, bc) -end -function getinterp_scalar( - ::Type{<:StructuredGrid}, A, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 - ) - return get_interpolator(StructuredGrid, A, gridr, gridθ, gridϕ, order, bc) -end - -function get_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) -end - -function get_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, - gridx, gridy, gridz, order::Int = 1, bc::Int = 1 - ) where {T} - itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) + return FieldInterpolator1D(itp, gridx, bc, dir) end function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} @@ -375,32 +348,20 @@ function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} return new_grid_vec, new_A end -function get_interpolator( - ::Type{<:StructuredGrid}, - A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 - ) where {T} - As = reinterpret(reshape, SVector{3, T}, A) - return get_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) -end - -function get_interpolator( - ::Type{<:StructuredGrid}, - A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 - ) where {T} - is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) +# Time-dependent field interpolation. - if is_uniform_r - itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - else - gridϕ, A = _ensure_full_phi(gridϕ, A) - itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - end +""" + LazyTimeInterpolator{T, F, L} - return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) -end +A callable struct for handling time-dependent fields with lazy loading and linear time interpolation. -# Time-dependent field interpolation. +# Fields +- `times::Vector{T}`: Sorted vector of time points. +- `loader::L`: Function `i -> field` that loads the field at index `i`. +- `buffer::Dict{Int, F}`: Cache for loaded fields. +- `lock::ReentrantLock`: Lock for thread safety. +""" struct LazyTimeInterpolator{T, F, L} <: Function times::Vector{T} loader::L @@ -409,6 +370,7 @@ struct LazyTimeInterpolator{T, F, L} <: Function end function LazyTimeInterpolator(times::AbstractVector, loader::Function) + # Determine the field type by loading the first field f1 = loader(1) return _LazyTimeInterpolator(times, loader, f1) end @@ -422,19 +384,22 @@ function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) w end function (itp::LazyTimeInterpolator)(x, t) + # Find the time interval [t1, t2] such that t1 <= t <= t2 (assume times is sorted) idx = searchsortedlast(itp.times, t) + # Handle out-of-bounds if idx == 0 - return _get_field!(itp, 1)(x) + return _get_field!(itp, 1)(x) # clamp to start elseif idx >= length(itp.times) - return _get_field!(itp, length(itp.times))(x) + return _get_field!(itp, length(itp.times))(x) # clamp to end end t1 = itp.times[idx] t2 = itp.times[idx + 1] - w = (t - t1) / (t2 - t1) + w = (t - t1) / (t2 - t1) # linear weights + # Load fields (lazily) f1 = _get_field!(itp, idx) f2 = _get_field!(itp, idx + 1) @@ -444,7 +409,9 @@ end function _get_field!(itp::LazyTimeInterpolator, idx::Int) return lock(itp.lock) do if !haskey(itp.buffer, idx) + # Remove far-away indices filter!(p -> abs(p.first - idx) <= 1, itp.buffer) + field = itp.loader(idx) itp.buffer[idx] = field end diff --git a/test/runtests.jl b/test/runtests.jl index 7d11c2f3e..65a68ecff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -170,7 +170,7 @@ end r, θ, ϕ, B_sph = setup_spherical_field() param = prepare(r, θ, ϕ, zero_E, B_sph; gridtype = StructuredGrid) # Check field interpolation Bz - @test param[4](SA[1.0, 1.0, 1.0])[3] == 9.888387888463716e-9 + @test param[4](SA[1.0, 1.0, 1.0])[3] ≈ 9.888387888463716e-9 # Test Vector inputs for spherical grid r_vec = collect(r) @@ -181,7 +181,7 @@ end r_vec, θ_vec, ϕ_vec, zero_E, B_sph; species = Ion(m = 16, q = 1), gridtype = TP.StructuredGrid ) - @test param_vec[4](SA[1.0, 1.0, 1.0])[3] == 9.888387888463716e-9 + @test param_vec[4](SA[1.0, 1.0, 1.0])[3] ≈ 9.888387888463716e-9 end @testset "analytical field" begin diff --git a/test/test_utility.jl b/test/test_utility.jl index c45a68002..08b0107f6 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -50,23 +50,23 @@ import TestParticle as TP for i in eachindex(x), j in eachindex(y), k in eachindex(z) ] nfunc11 = TP.build_interpolator(n, x, y, z) - @test nfunc11(SA[9, 0, 0]) == 11.85 + @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 nfunc12 = TP.build_interpolator(n, x, y, z, 1, 2) - @test nfunc12(SA[20, 0, 0]) == 9.5 + @test nfunc12(SA[20, 0, 0]) ≈ 10.5f0 nfunc13 = TP.build_interpolator(n, x, y, z, 1, 3) - @test nfunc13(SA[20, 0, 0]) == 12.0 + @test nfunc13(SA[20, 0, 0]) ≈ 12.0f0 nfunc21 = TP.build_interpolator(n, x, y, z, 2) - @test nfunc21(SA[9, 0, 0]) == 11.898528302013874 + @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 nfunc22 = TP.build_interpolator(n, x, y, z, 2, 2) - @test nfunc22(SA[20, 0, 0]) == 9.166666686534882 + @test nfunc22(SA[20, 0, 0]) ≈ 10.5f0 nfunc23 = TP.build_interpolator(n, x, y, z, 2, 3) - @test nfunc23(SA[20, 0, 0]) == 12.14705765247345 + @test nfunc23(SA[20, 0, 0]) ≈ 12.0f0 nfunc31 = TP.build_interpolator(n, x, y, z, 3) - @test nfunc31(SA[9, 0, 0]) == 11.882999392215163 + @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 nfunc32 = TP.build_interpolator(n, x, y, z, 3, 2) - @test nfunc32(SA[20, 0, 0]) == 9.124999547351358 + @test nfunc32(SA[20, 0, 0]) ≈ 10.5f0 nfunc33 = TP.build_interpolator(n, x, y, z, 3, 3) - @test nfunc33(SA[20, 0, 0]) == 12.191176189381315 + @test nfunc33(SA[20, 0, 0]) ≈ 12.0f0 end begin # spherical interpolation From a463ff521cebecf24f01a6f1438c5229e2138cc8 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 16:41:58 -0500 Subject: [PATCH 18/27] Temporary version before FastInterpolations.jl supports constant extrapolation --- src/TestParticle.jl | 2 + src/utility/fastinterpolation.jl | 183 +++++++++++-------------------- test/test_utility.jl | 15 ++- 3 files changed, 75 insertions(+), 125 deletions(-) diff --git a/src/TestParticle.jl b/src/TestParticle.jl index 9ca9970ba..aa3ce2519 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -1,6 +1,8 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize +using FastInterpolations: linear_interp, quadratic_interp, cubic_interp, constant_interp, + Extrap, NoExtrap using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index f9b7a348f..4a2cc5d33 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -1,13 +1,4 @@ -# Field interpolations using FastInterpolations.jl. - -using FastInterpolations -using StaticArrays -using Adapt - -Base.eltype(::FastInterpolations.AbstractInterpolant{Tx, Ty}) where {Tx, Ty} = Ty -Base.eltype(::FastInterpolations.AbstractInterpolantND{Tx, Ty, N}) where {Tx, Ty, N} = Ty - -@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) +# Field interpolations. """ AbstractFieldInterpolator @@ -17,29 +8,17 @@ Abstract type for all field interpolators. abstract type AbstractFieldInterpolator <: Function end """ - FieldInterpolator{T, G} + FieldInterpolator{T} -A callable struct that wraps a 3D interpolation object and its grid. +A callable struct that wraps a 3D interpolation object. """ -struct FieldInterpolator{T, G} <: AbstractFieldInterpolator +struct FieldInterpolator{T} <: AbstractFieldInterpolator itp::T - grid::G - bc::Int end -_in_bounds(x, gridx) = first(gridx) <= x <= last(gridx) - -function (fi::FieldInterpolator)(xu) - if fi.bc == 1 - if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2]) && _in_bounds(xu[3], fi.grid[3])) - T_val = eltype(fi.itp) - if T_val <: SVector - return T_val(ntuple(_ -> NaN, Val(length(T_val)))) - else - return T_val(NaN) - end - end - end +const FieldInterpolator3D = FieldInterpolator + +@inbounds function (fi::FieldInterpolator)(xu) return fi.itp((xu[1], xu[2], xu[3])) end @@ -47,30 +26,19 @@ function (fi::FieldInterpolator)(xu, t) return fi(xu) end -Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) +Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp)) """ - FieldInterpolator2D{T, G} + FieldInterpolator2D{T} A callable struct that wraps a 2D interpolation object. """ -struct FieldInterpolator2D{T, G} <: AbstractFieldInterpolator +struct FieldInterpolator2D{T} <: AbstractFieldInterpolator itp::T - grid::G - bc::Int end -function (fi::FieldInterpolator2D)(xu) - if fi.bc == 1 - if !(_in_bounds(xu[1], fi.grid[1]) && _in_bounds(xu[2], fi.grid[2])) - T_val = eltype(fi.itp) - if T_val <: SVector - return T_val(ntuple(_ -> NaN, Val(length(T_val)))) - else - return T_val(NaN) - end - end - end +@inbounds function (fi::FieldInterpolator2D)(xu) + # 2D interpolation usually involves x and y return fi.itp((xu[1], xu[2])) end @@ -78,31 +46,19 @@ function (fi::FieldInterpolator2D)(xu, t) return fi(xu) end -Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) +Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp)) """ - FieldInterpolator1D{T, G} + FieldInterpolator1D{T} A callable struct that wraps a 1D interpolation object. """ -struct FieldInterpolator1D{T, G} <: AbstractFieldInterpolator +struct FieldInterpolator1D{T} <: AbstractFieldInterpolator itp::T - grid::G - bc::Int dir::Int end -function (fi::FieldInterpolator1D)(xu) - if fi.bc == 1 - if !_in_bounds(xu[fi.dir], fi.grid) - T_val = eltype(fi.itp) - if T_val <: SVector - return T_val(ntuple(_ -> NaN, Val(length(T_val)))) - else - return T_val(NaN) - end - end - end +@inbounds function (fi::FieldInterpolator1D)(xu) return fi.itp((xu[fi.dir],)) end @@ -110,37 +66,24 @@ function (fi::FieldInterpolator1D)(xu, t) return fi(xu) end -Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc, fi.dir) +Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), fi.dir) """ - SphericalFieldInterpolator{T, G} + SphericalFieldInterpolator{T} -A callable struct for spherical grid interpolation. +A callable struct for spherical grid interpolation (scalar or combined vector). """ -struct SphericalFieldInterpolator{T, G} <: AbstractFieldInterpolator +struct SphericalFieldInterpolator{T} <: AbstractFieldInterpolator itp::T - grid::G - bc::Int end function (fi::SphericalFieldInterpolator)(xu) - r_val, θ_val, ϕ_val = cart2sph(xu) - - if fi.bc == 1 - if !(_in_bounds(r_val, fi.grid[1]) && _in_bounds(θ_val, fi.grid[2]) && _in_bounds(ϕ_val, fi.grid[3])) - T_val = eltype(fi.itp) - if T_val <: SVector - return T_val(ntuple(_ -> NaN, Val(length(T_val)))) - else - return T_val(NaN) - end - end - end - - res = fi.itp((r_val, θ_val, ϕ_val)) - if typeof(res) <: SVector || length(res) > 1 + rθϕ = cart2sph(xu) + res = fi.itp(rθϕ) + if length(res) > 1 + # Convert vector result from spherical to cartesian basis Br, Bθ, Bϕ = res - return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) + return sph_to_cart_vector(Br, Bθ, Bϕ, rθϕ[2], rθϕ[3]) else return res end @@ -150,7 +93,7 @@ function (fi::SphericalFieldInterpolator)(xu, t) return fi(xu) end -Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp), Adapt.adapt(to, fi.grid), fi.bc) +Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) function _get_extrap_mode(bc) if bc == 2 @@ -158,42 +101,32 @@ function _get_extrap_mode(bc) elseif bc == 3 return Extrap(:constant) else - # For bc == 1, we handle NaN in wrapper and use NoExtrap() for inner to avoid errors + # TODO: bc == 1 (NaN outside domain) requires native FastInterpolations.jl support. + # Once available, replace this with the appropriate Extrap mode and remove the + # manual bounds-checking code that was previously in the FieldInterpolator call methods. return NoExtrap() end end -function _match_grid_type(g::AbstractRange, ::Type{T}) where {T <: AbstractFloat} - return range(T(first(g)), T(last(g)), length = length(g)) -end -function _match_grid_type(g::AbstractVector, ::Type{T}) where {T <: AbstractFloat} - return T.(g) -end -function _match_grid_type(g::Tuple, ::Type{T}) where {T <: AbstractFloat} - return T.(g) -end function _fastinterp(grids, A, order, bc) - T_A = eltype(A) - T_F = T_A <: SVector ? eltype(T_A) : T_A - T_F = T_F <: AbstractFloat ? T_F : Float64 - matched_grids = map(g -> _match_grid_type(g, T_F), grids) - extrap_mode = _get_extrap_mode(bc) if order == 1 - return linear_interp(matched_grids, A; extrap = extrap_mode) + return linear_interp(grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(matched_grids, A; extrap = extrap_mode) + return quadratic_interp(grids, A; extrap = extrap_mode) elseif order == 3 - return cubic_interp(matched_grids, A; extrap = extrap_mode) + return cubic_interp(grids, A; extrap = extrap_mode) else - return constant_interp(matched_grids, A; extrap = extrap_mode) + return constant_interp(grids, A; extrap = extrap_mode) end end +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) + """ - build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) - build_interpolator(A, grids..., order::Int=1, bc::Int=1) + build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=3) + build_interpolator(A, grids..., order::Int=1, bc::Int=3) Return a function for interpolating field array `A` on the given grids. @@ -202,14 +135,15 @@ Return a function for interpolating field array `A` on the given grids. - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. - `order::Int=1`: order of interpolation in [1,2,3]. - - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Flat. + - `bc::Int=3`: type of boundary conditions, 1 -> NaN (not yet native; requires FastInterpolations support), 2 -> periodic, 3 -> Flat. # Notes The input array `A` may be modified in-place for memory optimization. """ function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -218,18 +152,21 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 3 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) + # Return field value at a given location. + return FieldInterpolator(itp) end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -238,7 +175,8 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, - gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 + gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, + order::Int = 1, bc::Int = 3 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -248,7 +186,7 @@ function build_interpolator( end itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - return FieldInterpolator(itp, (gridx, gridy, gridz), bc) + return FieldInterpolator(itp) end function build_interpolator( @@ -267,20 +205,24 @@ function build_interpolator( if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end - + # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) + # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) if is_uniform_r itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - else + else # Non-uniform R (SphericalNonUniformR behavior) gridϕ, A = _ensure_full_phi(gridϕ, A) itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) end - return SphericalFieldInterpolator(itp, (gridr, gridθ, gridϕ), bc) + return SphericalFieldInterpolator(itp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2) +function build_interpolator( + ::Type{<:CartesianGrid}, A, + gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 3 + ) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." As = A @@ -290,10 +232,13 @@ function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, g end itp = _fastinterp((gridx, gridy), As, order, bc) - return FieldInterpolator2D(itp, (gridx, gridy), bc) + return FieldInterpolator2D(itp) end -function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1) +function build_interpolator( + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + order::Int = 1, bc::Int = 3; dir = 1 + ) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." As = A @@ -304,7 +249,7 @@ function build_interpolator(::Type{<:CartesianGrid}, A, gridx::AbstractVector, o itp = _fastinterp((gridx,), As, order, bc) - return FieldInterpolator1D(itp, gridx, bc, dir) + return FieldInterpolator1D(itp, dir) end function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} @@ -340,7 +285,7 @@ function _ensure_full_phi(gridϕ, A::AbstractArray{T, N}) where {T, N} if needs_2pi if needs_0 selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(new_A, phi_dim, 1) - else + else # needs 2π only selectdim(new_A, phi_dim, size(new_A, phi_dim)) .= selectdim(A, phi_dim, 1) end end diff --git a/test/test_utility.jl b/test/test_utility.jl index 08b0107f6..e1ea6453f 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -49,20 +49,23 @@ import TestParticle as TP Float32(i + j + k) for i in eachindex(x), j in eachindex(y), k in eachindex(z) ] - nfunc11 = TP.build_interpolator(n, x, y, z) - @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 + # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). + # nfunc11 = TP.build_interpolator(n, x, y, z) + # @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 nfunc12 = TP.build_interpolator(n, x, y, z, 1, 2) @test nfunc12(SA[20, 0, 0]) ≈ 10.5f0 nfunc13 = TP.build_interpolator(n, x, y, z, 1, 3) @test nfunc13(SA[20, 0, 0]) ≈ 12.0f0 - nfunc21 = TP.build_interpolator(n, x, y, z, 2) - @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 + # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). + # nfunc21 = TP.build_interpolator(n, x, y, z, 2) + # @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 nfunc22 = TP.build_interpolator(n, x, y, z, 2, 2) @test nfunc22(SA[20, 0, 0]) ≈ 10.5f0 nfunc23 = TP.build_interpolator(n, x, y, z, 2, 3) @test nfunc23(SA[20, 0, 0]) ≈ 12.0f0 - nfunc31 = TP.build_interpolator(n, x, y, z, 3) - @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 + # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). + # nfunc31 = TP.build_interpolator(n, x, y, z, 3) + # @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 nfunc32 = TP.build_interpolator(n, x, y, z, 3, 2) @test nfunc32(SA[20, 0, 0]) ≈ 10.5f0 nfunc33 = TP.build_interpolator(n, x, y, z, 3, 3) From 2da4f079b2710f897462d46aa249899dbb0644f8 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 10:24:15 -0500 Subject: [PATCH 19/27] refactor: substitute Interpolations.jl with FastInterpolations.jl --- src/utility/fastinterpolation.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 4a2cc5d33..1283887e1 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -108,7 +108,6 @@ function _get_extrap_mode(bc) end end - function _fastinterp(grids, A, order, bc) extrap_mode = _get_extrap_mode(bc) if order == 1 From c7c0a339f5a665ffcab4fe2cd93b0058d8246c61 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Sun, 1 Mar 2026 22:50:15 -0500 Subject: [PATCH 20/27] Try to fix the Float32 vs Float64 bug --- src/utility/fastinterpolation.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 1283887e1..9a21a914f 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -109,15 +109,20 @@ function _get_extrap_mode(bc) end function _fastinterp(grids, A, order, bc) + T_A = eltype(A) + T_F = T_A <: SVector ? eltype(T_A) : T_A + T_F = T_F <: AbstractFloat ? T_F : Float64 + matched_grids = map(g -> _match_grid_type(g, T_F), grids) + extrap_mode = _get_extrap_mode(bc) if order == 1 - return linear_interp(grids, A; extrap = extrap_mode) + return linear_interp(matched_grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(grids, A; extrap = extrap_mode) + return quadratic_interp(matched_grids, A; extrap = extrap_mode) elseif order == 3 return cubic_interp(grids, A; extrap = extrap_mode) else - return constant_interp(grids, A; extrap = extrap_mode) + return constant_interp(matched_grids, A; extrap = extrap_mode) end end From 585a7024ba42eac95f048adbd62ef20cd9b3edb3 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 21:30:52 -0500 Subject: [PATCH 21/27] Simplify spherical grid support --- src/TestParticle.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TestParticle.jl b/src/TestParticle.jl index aa3ce2519..1d42727c5 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -2,7 +2,7 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize using FastInterpolations: linear_interp, quadratic_interp, cubic_interp, constant_interp, - Extrap, NoExtrap + Extrap, NoExtrap, PeriodicBC, ZeroCurvBC using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, From d3119a45cd9c44ff9a4773d15280c357a367c6f1 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Mar 2026 21:39:00 -0500 Subject: [PATCH 22/27] Update interpolation.jl --- src/utility/interpolation.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index cbec3f8e3..5bb31ca5d 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -231,6 +231,9 @@ function build_interpolator( itp = scale(itp_obj, gridr, gridθ, gridϕ) end + bctype = (Flat(), Flat(), phi_bc) + itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) + return SphericalFieldInterpolator(itp) end From 99267912c4c27ae72f9cdc7a5abb1868306ecbac Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Mar 2026 10:49:11 -0500 Subject: [PATCH 23/27] Experimenting backend support for both Interpolations.jl and FastInterpolations.jl --- Project.toml | 2 + src/TestParticle.jl | 4 + src/prepare.jl | 8 +- src/utility/fastinterpolation.jl | 48 ++-- src/utility/interp_backends.jl | 15 ++ src/utility/interpolation.jl | 339 +++++++--------------------- test/Project.toml | 1 + test/runtests.jl | 2 + test/test_interpolations_backend.jl | 105 +++++++++ 9 files changed, 240 insertions(+), 284 deletions(-) create mode 100644 src/utility/interp_backends.jl create mode 100644 test/test_interpolations_backend.jl diff --git a/Project.toml b/Project.toml index 2a1451034..c691c80b1 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" @@ -40,6 +41,7 @@ DiffResults = "1" Distributed = "1" FastInterpolations = "0.4" ForwardDiff = "1" +Interpolations = "0.16" KernelAbstractions = "0.9" LinearAlgebra = "1" Meshes = "0.55, 0.56" diff --git a/src/TestParticle.jl b/src/TestParticle.jl index 1d42727c5..fc56399d0 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -3,6 +3,7 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize using FastInterpolations: linear_interp, quadratic_interp, cubic_interp, constant_interp, Extrap, NoExtrap, PeriodicBC, ZeroCurvBC +import Interpolations using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, @@ -40,12 +41,15 @@ export get_gyrofrequency, export orbit, monitor export get_fields, get_work export LazyTimeInterpolator +export AbstractInterpolationBackend, FastInterpolationsBackend, InterpolationsBackend export TraceProblem, TraceGCProblem, TraceHybridProblem, CartesianGrid, RectilinearGrid, StructuredGrid export EnsembleSerial, EnsembleThreads, EnsembleDistributed, remake include("types.jl") include("utility/utility.jl") +include("utility/interp_backends.jl") include("utility/fastinterpolation.jl") +include("utility/interpolation.jl") include("sampler.jl") include("prepare.jl") include("gc.jl") diff --git a/src/prepare.jl b/src/prepare.jl index 01ff73475..172d2eff4 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -64,8 +64,12 @@ get_EField(param) = param[3] prepare_field(f, args...; kwargs...) = Field(f) prepare_field(f::ZeroField, args...; kwargs...) = f -function prepare_field(f::AbstractArray, x...; gridtype, order, bc, kw...) - return Field(build_interpolator(gridtype, f, x..., order, bc; kw...)) +function prepare_field( + f::AbstractArray, x...; + gridtype, order, bc, + backend::AbstractInterpolationBackend = FastInterpolationsBackend(), kw... + ) + return Field(build_interpolator(backend, gridtype, f, x..., order, bc; kw...)) end function _prepare( diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 9a21a914f..cc3ea1fba 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -38,7 +38,6 @@ struct FieldInterpolator2D{T} <: AbstractFieldInterpolator end @inbounds function (fi::FieldInterpolator2D)(xu) - # 2D interpolation usually involves x and y return fi.itp((xu[1], xu[2])) end @@ -81,7 +80,6 @@ function (fi::SphericalFieldInterpolator)(xu) rθϕ = cart2sph(xu) res = fi.itp(rθϕ) if length(res) > 1 - # Convert vector result from spherical to cartesian basis Br, Bθ, Bϕ = res return sph_to_cart_vector(Br, Bθ, Bϕ, rθϕ[2], rθϕ[3]) else @@ -109,26 +107,25 @@ function _get_extrap_mode(bc) end function _fastinterp(grids, A, order, bc) - T_A = eltype(A) - T_F = T_A <: SVector ? eltype(T_A) : T_A - T_F = T_F <: AbstractFloat ? T_F : Float64 - matched_grids = map(g -> _match_grid_type(g, T_F), grids) - extrap_mode = _get_extrap_mode(bc) if order == 1 - return linear_interp(matched_grids, A; extrap = extrap_mode) + return linear_interp(grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(matched_grids, A; extrap = extrap_mode) + return quadratic_interp(grids, A; extrap = extrap_mode) elseif order == 3 return cubic_interp(grids, A; extrap = extrap_mode) else - return constant_interp(matched_grids, A; extrap = extrap_mode) + return constant_interp(grids, A; extrap = extrap_mode) end end -@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) +@inline build_interpolator(A, grid1, args...) = + build_interpolator(FastInterpolationsBackend(), CartesianGrid, A, grid1, args...) +@inline build_interpolator(gridtype::Type, A, args...) = + build_interpolator(FastInterpolationsBackend(), gridtype, A, args...) """ + build_interpolator(backend, gridtype, A, grids..., order::Int=1, bc::Int=3) build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=3) build_interpolator(A, grids..., order::Int=1, bc::Int=3) @@ -136,6 +133,7 @@ Return a function for interpolating field array `A` on the given grids. # Arguments + - `backend`: `FastInterpolationsBackend()` (default) or `InterpolationsBackend()`. - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. - `order::Int=1`: order of interpolation in [1,2,3]. @@ -145,17 +143,17 @@ Return a function for interpolating field array `A` on the given grids. The input array `A` may be modified in-place for memory optimization. """ function build_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + b::FastInterpolationsBackend, ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) + return build_interpolator(b, CartesianGrid, As, gridx, gridy, gridz, order, bc) end function build_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + ::FastInterpolationsBackend, ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 3 ) where {T} @@ -163,22 +161,21 @@ function build_interpolator( @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end itp = _fastinterp((gridx, gridy, gridz), A, order, bc) - # Return field value at a given location. return FieldInterpolator(itp) end function build_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + b::FastInterpolationsBackend, ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) + return build_interpolator(b, RectilinearGrid, As, gridx, gridy, gridz, order, bc) end function build_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + ::FastInterpolationsBackend, ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 3 ) where {T} @@ -194,28 +191,26 @@ function build_interpolator( end function build_interpolator( - ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, + b::FastInterpolationsBackend, ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) + return build_interpolator(b, StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) end function build_interpolator( - ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, + ::FastInterpolationsBackend, ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end - # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) - # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) if is_uniform_r itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - else # Non-uniform R (SphericalNonUniformR behavior) + else gridϕ, A = _ensure_full_phi(gridϕ, A) itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) end @@ -224,7 +219,7 @@ function build_interpolator( end function build_interpolator( - ::Type{<:CartesianGrid}, A, + ::FastInterpolationsBackend, ::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 3 ) if eltype(A) <: SVector @@ -240,7 +235,7 @@ function build_interpolator( end function build_interpolator( - ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + ::FastInterpolationsBackend, ::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1 ) if eltype(A) <: SVector @@ -252,7 +247,6 @@ function build_interpolator( end itp = _fastinterp((gridx,), As, order, bc) - return FieldInterpolator1D(itp, dir) end diff --git a/src/utility/interp_backends.jl b/src/utility/interp_backends.jl new file mode 100644 index 000000000..62e34a534 --- /dev/null +++ b/src/utility/interp_backends.jl @@ -0,0 +1,15 @@ +abstract type AbstractInterpolationBackend end + +""" + FastInterpolationsBackend + +Interpolation backend using FastInterpolations.jl (default). +""" +struct FastInterpolationsBackend <: AbstractInterpolationBackend end + +""" + InterpolationsBackend + +Interpolation backend using Interpolations.jl. +""" +struct InterpolationsBackend <: AbstractInterpolationBackend end diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 5bb31ca5d..99a853306 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -1,130 +1,75 @@ -# Field interpolations. +# Field interpolations using the Interpolations.jl backend. """ - AbstractFieldInterpolator + TupleCallAdaptor{T} -Abstract type for all field interpolators. +Wraps an Interpolations.jl interpolant so it can be called with a single +tuple argument, matching the call convention used by `FieldInterpolator`. """ -abstract type AbstractFieldInterpolator <: Function end - -""" - FieldInterpolator{T} - -A callable struct that wraps a 3D interpolation object. -""" -struct FieldInterpolator{T} <: AbstractFieldInterpolator - itp::T -end - -const FieldInterpolator3D = FieldInterpolator - -@inbounds function (fi::FieldInterpolator)(xu) - return fi.itp(xu[1], xu[2], xu[3]) -end - -function (fi::FieldInterpolator)(xu, t) - return fi(xu) -end - -Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp)) - -""" - FieldInterpolator2D{T} - -A callable struct that wraps a 2D interpolation object. -""" -struct FieldInterpolator2D{T} <: AbstractFieldInterpolator +struct TupleCallAdaptor{T} itp::T end -@inbounds function (fi::FieldInterpolator2D)(xu) - # 2D interpolation usually involves x and y - return fi.itp(xu[1], xu[2]) -end - -function (fi::FieldInterpolator2D)(xu, t) - return fi(xu) -end - -Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp)) +Adapt.adapt_structure(to, a::TupleCallAdaptor) = TupleCallAdaptor(Adapt.adapt(to, a.itp)) -""" - FieldInterpolator1D{T} - -A callable struct that wraps a 1D interpolation object. -""" -struct FieldInterpolator1D{T} <: AbstractFieldInterpolator - itp::T - dir::Int -end +@inline (a::TupleCallAdaptor{T})(t::NTuple{1}) where {T} = a.itp(t[1]) +@inline (a::TupleCallAdaptor{T})(t::NTuple{2}) where {T} = a.itp(t[1], t[2]) +@inline (a::TupleCallAdaptor{T})(t::NTuple{3}) where {T} = a.itp(t[1], t[2], t[3]) +@inline (a::TupleCallAdaptor{T})(t::AbstractVector) where {T} = a.itp(t[1], t[2], t[3]) -@inbounds function (fi::FieldInterpolator1D)(xu) - return fi.itp(xu[fi.dir]) -end +function _get_bspline(order::Int, periodic::Bool) + gt = Interpolations.OnCell() + interp_type = if order == 1 + Interpolations.Linear + elseif order == 2 + Interpolations.Quadratic + elseif order == 3 + Interpolations.Cubic + else + throw(ArgumentError("Unsupported interpolation order!")) + end -function (fi::FieldInterpolator1D)(xu, t) - return fi(xu) + if periodic + return Interpolations.BSpline(interp_type(Interpolations.Periodic(gt))) + else + if interp_type == Interpolations.Linear + return Interpolations.BSpline(Interpolations.Linear()) + else + return Interpolations.BSpline(interp_type(Interpolations.Flat(gt))) + end + end end -Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), fi.dir) - -""" - SphericalFieldInterpolator{T} - -A callable struct for spherical grid interpolation (scalar or combined vector). -""" -struct SphericalFieldInterpolator{T} <: AbstractFieldInterpolator - itp::T -end +function _get_interp_object(A, order::Int, bc::Int) + bspline = _get_bspline(order, bc == 2) -function (fi::SphericalFieldInterpolator)(xu) - r_val, θ_val, ϕ_val = cart2sph(xu) - res = fi.itp(r_val, θ_val, ϕ_val) - if length(res) > 1 - # Convert vector result from spherical to cartesian basis - Br, Bθ, Bϕ = res - return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) + bctype = if bc == 1 + if eltype(A) <: SVector + SVector{3, eltype(eltype(A))}(NaN, NaN, NaN) + else + eltype(A)(NaN) + end + elseif bc == 2 + Interpolations.Periodic() else - return res + Interpolations.Flat() end -end -function (fi::SphericalFieldInterpolator)(xu, t) - return fi(xu) + return Interpolations.extrapolate(Interpolations.interpolate(A, bspline), bctype) end -Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) - -@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) - -""" - build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) - build_interpolator(A, grids..., order::Int=1, bc::Int=1) - -Return a function for interpolating field array `A` on the given grids. - -# Arguments - - - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. - - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. - - `order::Int=1`: order of interpolation in [1,2,3]. - - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Flat. - -# Notes -The input array `A` may be modified in-place for memory optimization. -""" function build_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + b::InterpolationsBackend, ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) + return build_interpolator(b, CartesianGrid, As, gridx, gridy, gridz, order, bc) end function build_interpolator( - ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + ::InterpolationsBackend, ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 ) where {T} @@ -132,24 +77,22 @@ function build_interpolator( @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end itp = _get_interp_object(A, order, bc) - interp = scale(itp, gridx, gridy, gridz) - - # Return field value at a given location. - return FieldInterpolator(interp) + interp = Interpolations.scale(itp, gridx, gridy, gridz) + return FieldInterpolator(TupleCallAdaptor(interp)) end function build_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + b::InterpolationsBackend, ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) + return build_interpolator(b, RectilinearGrid, As, gridx, gridy, gridz, order, bc) end function build_interpolator( - ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + ::InterpolationsBackend, ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 ) where {T} @@ -167,27 +110,29 @@ function build_interpolator( T(NaN) end elseif bc == 2 - Periodic() + Interpolations.Periodic() else - Flat() + Interpolations.Flat() end - itp = extrapolate(interpolate!((gridx, gridy, gridz), A, Gridded(Linear())), bctype) - - return FieldInterpolator(itp) + itp = Interpolations.extrapolate( + Interpolations.interpolate!((gridx, gridy, gridz), A, Interpolations.Gridded(Interpolations.Linear())), + bctype + ) + return FieldInterpolator(TupleCallAdaptor(itp)) end function build_interpolator( - ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, + b::InterpolationsBackend, ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) + return build_interpolator(b, StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) end function build_interpolator( - ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, + ::InterpolationsBackend, ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @@ -205,40 +150,43 @@ function build_interpolator( has_2pi = isapprox(ϕ_max, 2π, atol = 1.0e-5) ϕ_bc = if has_0 && has_2pi - Periodic(OnGrid()) + Interpolations.Periodic(Interpolations.OnGrid()) else - Periodic(OnCell()) + Interpolations.Periodic(Interpolations.OnCell()) end - bctype = (Flat(), Flat(), ϕ_bc) + bctype = (Interpolations.Flat(), Interpolations.Flat(), ϕ_bc) if order == 1 - itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) + itp = Interpolations.extrapolate( + Interpolations.interpolate!( + (gridr, gridθ, gridϕ), A, + Interpolations.Gridded(Interpolations.Linear()) + ), + bctype + ) else interp_type = if order == 2 - Quadratic + Interpolations.Quadratic elseif order == 3 - Cubic + Interpolations.Cubic else throw(ArgumentError("Unsupported interpolation order!")) end itp_type = ( - BSpline(interp_type(Flat(OnCell()))), - BSpline(interp_type(Flat(OnCell()))), - BSpline(interp_type(ϕ_bc)), + Interpolations.BSpline(interp_type(Interpolations.Flat(Interpolations.OnCell()))), + Interpolations.BSpline(interp_type(Interpolations.Flat(Interpolations.OnCell()))), + Interpolations.BSpline(interp_type(ϕ_bc)), ) - itp_obj = extrapolate(interpolate(A, itp_type), bctype) - itp = scale(itp_obj, gridr, gridθ, gridϕ) + itp_obj = Interpolations.extrapolate(Interpolations.interpolate(A, itp_type), bctype) + itp = Interpolations.scale(itp_obj, gridr, gridθ, gridϕ) end - bctype = (Flat(), Flat(), phi_bc) - itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) - - return SphericalFieldInterpolator(itp) + return SphericalFieldInterpolator(TupleCallAdaptor(itp)) end function build_interpolator( - ::Type{<:CartesianGrid}, A, + ::InterpolationsBackend, ::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2 ) if eltype(A) <: SVector @@ -250,13 +198,12 @@ function build_interpolator( end itp = _get_interp_object(As, order, bc) - interp = scale(itp, gridx, gridy) - - return FieldInterpolator2D(interp) + interp = Interpolations.scale(itp, gridx, gridy) + return FieldInterpolator2D(TupleCallAdaptor(interp)) end function build_interpolator( - ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + ::InterpolationsBackend, ::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1 ) if eltype(A) <: SVector @@ -268,124 +215,6 @@ function build_interpolator( end itp = _get_interp_object(As, order, bc) - interp = scale(itp, gridx) - - return FieldInterpolator1D(interp, dir) -end - -# Internal Helpers - -function _get_bspline(order::Int, periodic::Bool) - gt = OnCell() - - interp_type = if order == 1 - Linear - elseif order == 2 - Quadratic - elseif order == 3 - Cubic - else - throw(ArgumentError("Unsupported interpolation order!")) - end - - if periodic - return BSpline(interp_type(Periodic(gt))) - else - # Linear() is special as it doesn't take an argument. - if interp_type == Linear - return BSpline(Linear()) - else - return BSpline(interp_type(Flat(gt))) - end - end -end - -function _get_interp_object(A, order::Int, bc::Int) - bspline = _get_bspline(order, bc == 2) - - bctype = if bc == 1 - if eltype(A) <: SVector - SVector{3, eltype(eltype(A))}(NaN, NaN, NaN) - else - eltype(eltype(A))(NaN) - end - elseif bc == 2 - Periodic() - else - Flat() - end - - return extrapolate(interpolate(A, bspline), bctype) -end - - -# Time-dependent field interpolation. - -""" - LazyTimeInterpolator{T, F, L} - -A callable struct for handling time-dependent fields with lazy loading and linear time interpolation. - -# Fields - -- `times::Vector{T}`: Sorted vector of time points. -- `loader::L`: Function `i -> field` that loads the field at index `i`. -- `buffer::Dict{Int, F}`: Cache for loaded fields. -- `lock::ReentrantLock`: Lock for thread safety. -""" -struct LazyTimeInterpolator{T, F, L} <: Function - times::Vector{T} - loader::L - buffer::Dict{Int, F} - lock::ReentrantLock -end - -function LazyTimeInterpolator(times::AbstractVector, loader::Function) - # Determine the field type by loading the first field - f1 = loader(1) - return _LazyTimeInterpolator(times, loader, f1) -end - -function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) where {F} - buffer = Dict{Int, F}(1 => f1) - lock = ReentrantLock() - return LazyTimeInterpolator{eltype(times), F, typeof(loader)}( - times, loader, buffer, lock - ) -end - -function (itp::LazyTimeInterpolator)(x, t) - # Find the time interval [t1, t2] such that t1 <= t <= t2 (assume times is sorted) - idx = searchsortedlast(itp.times, t) - - # Handle out-of-bounds - if idx == 0 - return _get_field!(itp, 1)(x) # clamp to start - elseif idx >= length(itp.times) - return _get_field!(itp, length(itp.times))(x) # clamp to end - end - - t1 = itp.times[idx] - t2 = itp.times[idx + 1] - - w = (t - t1) / (t2 - t1) # linear weights - - # Load fields (lazily) - f1 = _get_field!(itp, idx) - f2 = _get_field!(itp, idx + 1) - - return (1 - w) * f1(x) + w * f2(x) -end - -function _get_field!(itp::LazyTimeInterpolator, idx::Int) - return lock(itp.lock) do - if !haskey(itp.buffer, idx) - # Remove far-away indices - filter!(p -> abs(p.first - idx) <= 1, itp.buffer) - - field = itp.loader(idx) - itp.buffer[idx] = field - end - return itp.buffer[idx] - end + interp = Interpolations.scale(itp, gridx) + return FieldInterpolator1D(TupleCallAdaptor(interp), dir) end diff --git a/test/Project.toml b/test/Project.toml index e65c793d5..f9cfabc80 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Magnetostatics = "e551bd30-f216-4325-9989-532520624021" diff --git a/test/runtests.jl b/test/runtests.jl index 65a68ecff..e39c51ea1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -492,3 +492,5 @@ include("test_gc.jl") include("test_hybrid.jl") include("test_boris_kernel.jl") + +include("test_interpolations_backend.jl") diff --git a/test/test_interpolations_backend.jl b/test/test_interpolations_backend.jl new file mode 100644 index 000000000..9a2e045c2 --- /dev/null +++ b/test/test_interpolations_backend.jl @@ -0,0 +1,105 @@ +module test_interpolations_backend + +using Test +using TestParticle +using StaticArrays +import TestParticle as TP + +@testset "InterpolationsBackend" begin + backend = InterpolationsBackend() + + @testset "Cartesian 3D scalar" begin + x = range(-10, 10, length = 4) + y = range(-10, 10, length = 6) + z = range(-10, 10, length = 8) + n = Float64[ + i + j + k + for i in eachindex(x), j in eachindex(y), k in eachindex(z) + ] + + # Periodic BC: out-of-domain point should return a finite wrapped value + nfunc = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 1, 2) + @test isfinite(nfunc(SA[20, 0, 0])) + + # Flat BC: out-of-domain point should clamp to boundary value + nfunc_flat = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 1, 3) + @test nfunc_flat(SA[20, 0, 0]) ≈ 12.0 + + # Interior point: exact for linear interpolation + nfunc_int = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 1, 3) + @test nfunc_int(SA[0, 0, 0]) ≈ 10.5 atol = 0.5 + + # Higher orders: just verify they run and return finite values + nfunc2 = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 2, 3) + @test isfinite(nfunc2(SA[0, 0, 0])) + + nfunc3 = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 3, 3) + @test isfinite(nfunc3(SA[0, 0, 0])) + end + + @testset "Cartesian 3D vector" begin + x = range(-10, 10, length = 5) + y = range(-10, 10, length = 5) + z = range(-10, 10, length = 5) + B = fill(0.0, 3, length(x), length(y), length(z)) + B[3, :, :, :] .= 1.0 + + Bfunc = TP.build_interpolator(backend, TP.CartesianGrid, B, x, y, z) + @test Bfunc(SA[0.0, 0.0, 0.0]) ≈ SA[0.0, 0.0, 1.0] + end + + @testset "RectilinearGrid" begin + x = [0.0, 1.0, 4.0, 9.0] + y = [0.0, 1.0, 4.0, 9.0] + z = [0.0, 1.0, 4.0, 9.0] + A = [i + j + k for i in x, j in y, k in z] + + Afunc = TP.build_interpolator(backend, TP.RectilinearGrid, A, x, y, z) + @test Afunc(SA[4.0, 4.0, 4.0]) ≈ 12.0 + @test Afunc(SA[1.0, 4.0, 9.0]) ≈ 14.0 + end + + @testset "Spherical 3D" begin + r = range(0.1, 10, length = 11) + θ = range(0, π, length = 11) + ϕ = range(0, 2π, length = 11) + B = fill(0.0, 3, length(r), length(θ), length(ϕ)) + B[1, :, :, :] .= 1.0 + + Bfunc = TP.build_interpolator(backend, TP.StructuredGrid, B, r, θ, ϕ) + @test Bfunc(SA[1, 1, 1]) ≈ [0.57735, 0.57735, 0.57735] atol = 1.0e-5 + + A = ones(length(r), length(θ), length(ϕ)) + Afunc = TP.build_interpolator(backend, TP.StructuredGrid, A, r, θ, ϕ) + @test Afunc(SA[1, 1, 1]) == 1.0 + end + + @testset "prepare kwarg" begin + x = range(-5, 5, length = 10) + y = range(-5, 5, length = 10) + z = range(-5, 5, length = 10) + B = fill(0.0, 3, length(x), length(y), length(z)) + B[3, :, :, :] .= 1.0 + E = fill(0.0, 3, length(x), length(y), length(z)) + + param = prepare(x, y, z, E, B; backend = InterpolationsBackend()) + @test param[4](SA[0.0, 0.0, 0.0]) ≈ SA[0.0, 0.0, 1.0] + end + + @testset "Results match FastInterpolations" begin + x = range(-5, 5, length = 10) + y = range(-5, 5, length = 10) + z = range(-5, 5, length = 10) + n = Float64[ + sin(i) + cos(j) + i * j * k + for i in x, j in y, k in z + ] + + pt = SA[1.0, -1.0, 2.0] + fi_val = TP.build_interpolator(TP.CartesianGrid, n, x, y, z)(pt) + int_val = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z)(pt) + @test fi_val ≈ int_val atol = 1.0e-4 + end +end + +end # module test_interpolations_backend From cdc09a269b2f0e8019aa3ddf992ba6f793412f4e Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Mar 2026 13:26:30 -0500 Subject: [PATCH 24/27] Revert "Merge branch 'fast-interpolations' of https://github.com/henry2004y/TestParticle.jl into fast-interpolations" This reverts commit b1220a4236618b6a175eeeae4ae0c0725303b47e, reversing changes made to 99267912c4c27ae72f9cdc7a5abb1868306ecbac. --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 660813d64..c691c80b1 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" From d69d59481e536eb702557175a27e548f6e68241c Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Mar 2026 13:26:36 -0500 Subject: [PATCH 25/27] Revert "Experimenting backend support for both Interpolations.jl and FastInterpolations.jl" This reverts commit 99267912c4c27ae72f9cdc7a5abb1868306ecbac. --- Project.toml | 2 - src/TestParticle.jl | 4 - src/prepare.jl | 8 +- src/utility/fastinterpolation.jl | 48 ++-- src/utility/interp_backends.jl | 15 -- src/utility/interpolation.jl | 339 +++++++++++++++++++++------- test/Project.toml | 1 - test/runtests.jl | 2 - test/test_interpolations_backend.jl | 105 --------- 9 files changed, 284 insertions(+), 240 deletions(-) delete mode 100644 src/utility/interp_backends.jl delete mode 100644 test/test_interpolations_backend.jl diff --git a/Project.toml b/Project.toml index c691c80b1..2a1451034 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" @@ -41,7 +40,6 @@ DiffResults = "1" Distributed = "1" FastInterpolations = "0.4" ForwardDiff = "1" -Interpolations = "0.16" KernelAbstractions = "0.9" LinearAlgebra = "1" Meshes = "0.55, 0.56" diff --git a/src/TestParticle.jl b/src/TestParticle.jl index fc56399d0..1d42727c5 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -3,7 +3,6 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize using FastInterpolations: linear_interp, quadratic_interp, cubic_interp, constant_interp, Extrap, NoExtrap, PeriodicBC, ZeroCurvBC -import Interpolations using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, @@ -41,15 +40,12 @@ export get_gyrofrequency, export orbit, monitor export get_fields, get_work export LazyTimeInterpolator -export AbstractInterpolationBackend, FastInterpolationsBackend, InterpolationsBackend export TraceProblem, TraceGCProblem, TraceHybridProblem, CartesianGrid, RectilinearGrid, StructuredGrid export EnsembleSerial, EnsembleThreads, EnsembleDistributed, remake include("types.jl") include("utility/utility.jl") -include("utility/interp_backends.jl") include("utility/fastinterpolation.jl") -include("utility/interpolation.jl") include("sampler.jl") include("prepare.jl") include("gc.jl") diff --git a/src/prepare.jl b/src/prepare.jl index 172d2eff4..01ff73475 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -64,12 +64,8 @@ get_EField(param) = param[3] prepare_field(f, args...; kwargs...) = Field(f) prepare_field(f::ZeroField, args...; kwargs...) = f -function prepare_field( - f::AbstractArray, x...; - gridtype, order, bc, - backend::AbstractInterpolationBackend = FastInterpolationsBackend(), kw... - ) - return Field(build_interpolator(backend, gridtype, f, x..., order, bc; kw...)) +function prepare_field(f::AbstractArray, x...; gridtype, order, bc, kw...) + return Field(build_interpolator(gridtype, f, x..., order, bc; kw...)) end function _prepare( diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index cc3ea1fba..9a21a914f 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -38,6 +38,7 @@ struct FieldInterpolator2D{T} <: AbstractFieldInterpolator end @inbounds function (fi::FieldInterpolator2D)(xu) + # 2D interpolation usually involves x and y return fi.itp((xu[1], xu[2])) end @@ -80,6 +81,7 @@ function (fi::SphericalFieldInterpolator)(xu) rθϕ = cart2sph(xu) res = fi.itp(rθϕ) if length(res) > 1 + # Convert vector result from spherical to cartesian basis Br, Bθ, Bϕ = res return sph_to_cart_vector(Br, Bθ, Bϕ, rθϕ[2], rθϕ[3]) else @@ -107,25 +109,26 @@ function _get_extrap_mode(bc) end function _fastinterp(grids, A, order, bc) + T_A = eltype(A) + T_F = T_A <: SVector ? eltype(T_A) : T_A + T_F = T_F <: AbstractFloat ? T_F : Float64 + matched_grids = map(g -> _match_grid_type(g, T_F), grids) + extrap_mode = _get_extrap_mode(bc) if order == 1 - return linear_interp(grids, A; extrap = extrap_mode) + return linear_interp(matched_grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(grids, A; extrap = extrap_mode) + return quadratic_interp(matched_grids, A; extrap = extrap_mode) elseif order == 3 return cubic_interp(grids, A; extrap = extrap_mode) else - return constant_interp(grids, A; extrap = extrap_mode) + return constant_interp(matched_grids, A; extrap = extrap_mode) end end -@inline build_interpolator(A, grid1, args...) = - build_interpolator(FastInterpolationsBackend(), CartesianGrid, A, grid1, args...) -@inline build_interpolator(gridtype::Type, A, args...) = - build_interpolator(FastInterpolationsBackend(), gridtype, A, args...) +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) """ - build_interpolator(backend, gridtype, A, grids..., order::Int=1, bc::Int=3) build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=3) build_interpolator(A, grids..., order::Int=1, bc::Int=3) @@ -133,7 +136,6 @@ Return a function for interpolating field array `A` on the given grids. # Arguments - - `backend`: `FastInterpolationsBackend()` (default) or `InterpolationsBackend()`. - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. - `order::Int=1`: order of interpolation in [1,2,3]. @@ -143,17 +145,17 @@ Return a function for interpolating field array `A` on the given grids. The input array `A` may be modified in-place for memory optimization. """ function build_interpolator( - b::FastInterpolationsBackend, ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(b, CartesianGrid, As, gridx, gridy, gridz, order, bc) + return build_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) end function build_interpolator( - ::FastInterpolationsBackend, ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 3 ) where {T} @@ -161,21 +163,22 @@ function build_interpolator( @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end itp = _fastinterp((gridx, gridy, gridz), A, order, bc) + # Return field value at a given location. return FieldInterpolator(itp) end function build_interpolator( - b::FastInterpolationsBackend, ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(b, RectilinearGrid, As, gridx, gridy, gridz, order, bc) + return build_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) end function build_interpolator( - ::FastInterpolationsBackend, ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 3 ) where {T} @@ -191,26 +194,28 @@ function build_interpolator( end function build_interpolator( - b::FastInterpolationsBackend, ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, + ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(b, StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) + return build_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) end function build_interpolator( - ::FastInterpolationsBackend, ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, + ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end + # Detect if uniform grid (old Spherical) or non-uniform r (old SphericalNonUniformR) + # We check if gridr is an AbstractRange, e.g. Base.LogRange is an AbstractRange but not uniform! is_uniform_r = gridr isa AbstractRange && !(gridr isa Base.LogRange) if is_uniform_r itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) - else + else # Non-uniform R (SphericalNonUniformR behavior) gridϕ, A = _ensure_full_phi(gridϕ, A) itp = _fastinterp((gridr, gridθ, gridϕ), A, order, bc) end @@ -219,7 +224,7 @@ function build_interpolator( end function build_interpolator( - ::FastInterpolationsBackend, ::Type{<:CartesianGrid}, A, + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 3 ) if eltype(A) <: SVector @@ -235,7 +240,7 @@ function build_interpolator( end function build_interpolator( - ::FastInterpolationsBackend, ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1 ) if eltype(A) <: SVector @@ -247,6 +252,7 @@ function build_interpolator( end itp = _fastinterp((gridx,), As, order, bc) + return FieldInterpolator1D(itp, dir) end diff --git a/src/utility/interp_backends.jl b/src/utility/interp_backends.jl deleted file mode 100644 index 62e34a534..000000000 --- a/src/utility/interp_backends.jl +++ /dev/null @@ -1,15 +0,0 @@ -abstract type AbstractInterpolationBackend end - -""" - FastInterpolationsBackend - -Interpolation backend using FastInterpolations.jl (default). -""" -struct FastInterpolationsBackend <: AbstractInterpolationBackend end - -""" - InterpolationsBackend - -Interpolation backend using Interpolations.jl. -""" -struct InterpolationsBackend <: AbstractInterpolationBackend end diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 99a853306..5bb31ca5d 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -1,75 +1,130 @@ -# Field interpolations using the Interpolations.jl backend. +# Field interpolations. """ - TupleCallAdaptor{T} + AbstractFieldInterpolator -Wraps an Interpolations.jl interpolant so it can be called with a single -tuple argument, matching the call convention used by `FieldInterpolator`. +Abstract type for all field interpolators. """ -struct TupleCallAdaptor{T} +abstract type AbstractFieldInterpolator <: Function end + +""" + FieldInterpolator{T} + +A callable struct that wraps a 3D interpolation object. +""" +struct FieldInterpolator{T} <: AbstractFieldInterpolator itp::T end -Adapt.adapt_structure(to, a::TupleCallAdaptor) = TupleCallAdaptor(Adapt.adapt(to, a.itp)) +const FieldInterpolator3D = FieldInterpolator -@inline (a::TupleCallAdaptor{T})(t::NTuple{1}) where {T} = a.itp(t[1]) -@inline (a::TupleCallAdaptor{T})(t::NTuple{2}) where {T} = a.itp(t[1], t[2]) -@inline (a::TupleCallAdaptor{T})(t::NTuple{3}) where {T} = a.itp(t[1], t[2], t[3]) -@inline (a::TupleCallAdaptor{T})(t::AbstractVector) where {T} = a.itp(t[1], t[2], t[3]) +@inbounds function (fi::FieldInterpolator)(xu) + return fi.itp(xu[1], xu[2], xu[3]) +end -function _get_bspline(order::Int, periodic::Bool) - gt = Interpolations.OnCell() - interp_type = if order == 1 - Interpolations.Linear - elseif order == 2 - Interpolations.Quadratic - elseif order == 3 - Interpolations.Cubic - else - throw(ArgumentError("Unsupported interpolation order!")) - end +function (fi::FieldInterpolator)(xu, t) + return fi(xu) +end - if periodic - return Interpolations.BSpline(interp_type(Interpolations.Periodic(gt))) - else - if interp_type == Interpolations.Linear - return Interpolations.BSpline(Interpolations.Linear()) - else - return Interpolations.BSpline(interp_type(Interpolations.Flat(gt))) - end - end +Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp)) + +""" + FieldInterpolator2D{T} + +A callable struct that wraps a 2D interpolation object. +""" +struct FieldInterpolator2D{T} <: AbstractFieldInterpolator + itp::T end -function _get_interp_object(A, order::Int, bc::Int) - bspline = _get_bspline(order, bc == 2) +@inbounds function (fi::FieldInterpolator2D)(xu) + # 2D interpolation usually involves x and y + return fi.itp(xu[1], xu[2]) +end - bctype = if bc == 1 - if eltype(A) <: SVector - SVector{3, eltype(eltype(A))}(NaN, NaN, NaN) - else - eltype(A)(NaN) - end - elseif bc == 2 - Interpolations.Periodic() +function (fi::FieldInterpolator2D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp)) + +""" + FieldInterpolator1D{T} + +A callable struct that wraps a 1D interpolation object. +""" +struct FieldInterpolator1D{T} <: AbstractFieldInterpolator + itp::T + dir::Int +end + +@inbounds function (fi::FieldInterpolator1D)(xu) + return fi.itp(xu[fi.dir]) +end + +function (fi::FieldInterpolator1D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), fi.dir) + +""" + SphericalFieldInterpolator{T} + +A callable struct for spherical grid interpolation (scalar or combined vector). +""" +struct SphericalFieldInterpolator{T} <: AbstractFieldInterpolator + itp::T +end + +function (fi::SphericalFieldInterpolator)(xu) + r_val, θ_val, ϕ_val = cart2sph(xu) + res = fi.itp(r_val, θ_val, ϕ_val) + if length(res) > 1 + # Convert vector result from spherical to cartesian basis + Br, Bθ, Bϕ = res + return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) else - Interpolations.Flat() + return res end +end - return Interpolations.extrapolate(Interpolations.interpolate(A, bspline), bctype) +function (fi::SphericalFieldInterpolator)(xu, t) + return fi(xu) end +Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) + +@inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) + +""" + build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) + build_interpolator(A, grids..., order::Int=1, bc::Int=1) + +Return a function for interpolating field array `A` on the given grids. + +# Arguments + + - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. + - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. + - `order::Int=1`: order of interpolation in [1,2,3]. + - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Flat. + +# Notes +The input array `A` may be modified in-place for memory optimization. +""" function build_interpolator( - b::InterpolationsBackend, ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, + ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(b, CartesianGrid, As, gridx, gridy, gridz, order, bc) + return build_interpolator(CartesianGrid, As, gridx, gridy, gridz, order, bc) end function build_interpolator( - ::InterpolationsBackend, ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, + ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 ) where {T} @@ -77,22 +132,24 @@ function build_interpolator( @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." end itp = _get_interp_object(A, order, bc) - interp = Interpolations.scale(itp, gridx, gridy, gridz) - return FieldInterpolator(TupleCallAdaptor(interp)) + interp = scale(itp, gridx, gridy, gridz) + + # Return field value at a given location. + return FieldInterpolator(interp) end function build_interpolator( - b::InterpolationsBackend, ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(b, RectilinearGrid, As, gridx, gridy, gridz, order, bc) + return build_interpolator(RectilinearGrid, As, gridx, gridy, gridz, order, bc) end function build_interpolator( - ::InterpolationsBackend, ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, + ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, order::Int = 1, bc::Int = 1 ) where {T} @@ -110,29 +167,27 @@ function build_interpolator( T(NaN) end elseif bc == 2 - Interpolations.Periodic() + Periodic() else - Interpolations.Flat() + Flat() end - itp = Interpolations.extrapolate( - Interpolations.interpolate!((gridx, gridy, gridz), A, Interpolations.Gridded(Interpolations.Linear())), - bctype - ) - return FieldInterpolator(TupleCallAdaptor(itp)) + itp = extrapolate(interpolate!((gridx, gridy, gridz), A, Gridded(Linear())), bctype) + + return FieldInterpolator(itp) end function build_interpolator( - b::InterpolationsBackend, ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, + ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) - return build_interpolator(b, StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) + return build_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) end function build_interpolator( - ::InterpolationsBackend, ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, + ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @@ -150,43 +205,40 @@ function build_interpolator( has_2pi = isapprox(ϕ_max, 2π, atol = 1.0e-5) ϕ_bc = if has_0 && has_2pi - Interpolations.Periodic(Interpolations.OnGrid()) + Periodic(OnGrid()) else - Interpolations.Periodic(Interpolations.OnCell()) + Periodic(OnCell()) end - bctype = (Interpolations.Flat(), Interpolations.Flat(), ϕ_bc) + bctype = (Flat(), Flat(), ϕ_bc) if order == 1 - itp = Interpolations.extrapolate( - Interpolations.interpolate!( - (gridr, gridθ, gridϕ), A, - Interpolations.Gridded(Interpolations.Linear()) - ), - bctype - ) + itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) else interp_type = if order == 2 - Interpolations.Quadratic + Quadratic elseif order == 3 - Interpolations.Cubic + Cubic else throw(ArgumentError("Unsupported interpolation order!")) end itp_type = ( - Interpolations.BSpline(interp_type(Interpolations.Flat(Interpolations.OnCell()))), - Interpolations.BSpline(interp_type(Interpolations.Flat(Interpolations.OnCell()))), - Interpolations.BSpline(interp_type(ϕ_bc)), + BSpline(interp_type(Flat(OnCell()))), + BSpline(interp_type(Flat(OnCell()))), + BSpline(interp_type(ϕ_bc)), ) - itp_obj = Interpolations.extrapolate(Interpolations.interpolate(A, itp_type), bctype) - itp = Interpolations.scale(itp_obj, gridr, gridθ, gridϕ) + itp_obj = extrapolate(interpolate(A, itp_type), bctype) + itp = scale(itp_obj, gridr, gridθ, gridϕ) end - return SphericalFieldInterpolator(TupleCallAdaptor(itp)) + bctype = (Flat(), Flat(), phi_bc) + itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) + + return SphericalFieldInterpolator(itp) end function build_interpolator( - ::InterpolationsBackend, ::Type{<:CartesianGrid}, A, + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2 ) if eltype(A) <: SVector @@ -198,12 +250,13 @@ function build_interpolator( end itp = _get_interp_object(As, order, bc) - interp = Interpolations.scale(itp, gridx, gridy) - return FieldInterpolator2D(TupleCallAdaptor(interp)) + interp = scale(itp, gridx, gridy) + + return FieldInterpolator2D(interp) end function build_interpolator( - ::InterpolationsBackend, ::Type{<:CartesianGrid}, A, gridx::AbstractVector, + ::Type{<:CartesianGrid}, A, gridx::AbstractVector, order::Int = 1, bc::Int = 3; dir = 1 ) if eltype(A) <: SVector @@ -215,6 +268,124 @@ function build_interpolator( end itp = _get_interp_object(As, order, bc) - interp = Interpolations.scale(itp, gridx) - return FieldInterpolator1D(TupleCallAdaptor(interp), dir) + interp = scale(itp, gridx) + + return FieldInterpolator1D(interp, dir) +end + +# Internal Helpers + +function _get_bspline(order::Int, periodic::Bool) + gt = OnCell() + + interp_type = if order == 1 + Linear + elseif order == 2 + Quadratic + elseif order == 3 + Cubic + else + throw(ArgumentError("Unsupported interpolation order!")) + end + + if periodic + return BSpline(interp_type(Periodic(gt))) + else + # Linear() is special as it doesn't take an argument. + if interp_type == Linear + return BSpline(Linear()) + else + return BSpline(interp_type(Flat(gt))) + end + end +end + +function _get_interp_object(A, order::Int, bc::Int) + bspline = _get_bspline(order, bc == 2) + + bctype = if bc == 1 + if eltype(A) <: SVector + SVector{3, eltype(eltype(A))}(NaN, NaN, NaN) + else + eltype(eltype(A))(NaN) + end + elseif bc == 2 + Periodic() + else + Flat() + end + + return extrapolate(interpolate(A, bspline), bctype) +end + + +# Time-dependent field interpolation. + +""" + LazyTimeInterpolator{T, F, L} + +A callable struct for handling time-dependent fields with lazy loading and linear time interpolation. + +# Fields + +- `times::Vector{T}`: Sorted vector of time points. +- `loader::L`: Function `i -> field` that loads the field at index `i`. +- `buffer::Dict{Int, F}`: Cache for loaded fields. +- `lock::ReentrantLock`: Lock for thread safety. +""" +struct LazyTimeInterpolator{T, F, L} <: Function + times::Vector{T} + loader::L + buffer::Dict{Int, F} + lock::ReentrantLock +end + +function LazyTimeInterpolator(times::AbstractVector, loader::Function) + # Determine the field type by loading the first field + f1 = loader(1) + return _LazyTimeInterpolator(times, loader, f1) +end + +function _LazyTimeInterpolator(times::AbstractVector, loader::Function, f1::F) where {F} + buffer = Dict{Int, F}(1 => f1) + lock = ReentrantLock() + return LazyTimeInterpolator{eltype(times), F, typeof(loader)}( + times, loader, buffer, lock + ) +end + +function (itp::LazyTimeInterpolator)(x, t) + # Find the time interval [t1, t2] such that t1 <= t <= t2 (assume times is sorted) + idx = searchsortedlast(itp.times, t) + + # Handle out-of-bounds + if idx == 0 + return _get_field!(itp, 1)(x) # clamp to start + elseif idx >= length(itp.times) + return _get_field!(itp, length(itp.times))(x) # clamp to end + end + + t1 = itp.times[idx] + t2 = itp.times[idx + 1] + + w = (t - t1) / (t2 - t1) # linear weights + + # Load fields (lazily) + f1 = _get_field!(itp, idx) + f2 = _get_field!(itp, idx + 1) + + return (1 - w) * f1(x) + w * f2(x) +end + +function _get_field!(itp::LazyTimeInterpolator, idx::Int) + return lock(itp.lock) do + if !haskey(itp.buffer, idx) + # Remove far-away indices + filter!(p -> abs(p.first - idx) <= 1, itp.buffer) + + field = itp.loader(idx) + itp.buffer[idx] = field + end + return itp.buffer[idx] + end end diff --git a/test/Project.toml b/test/Project.toml index f9cfabc80..e65c793d5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,5 @@ [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Magnetostatics = "e551bd30-f216-4325-9989-532520624021" diff --git a/test/runtests.jl b/test/runtests.jl index e39c51ea1..65a68ecff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -492,5 +492,3 @@ include("test_gc.jl") include("test_hybrid.jl") include("test_boris_kernel.jl") - -include("test_interpolations_backend.jl") diff --git a/test/test_interpolations_backend.jl b/test/test_interpolations_backend.jl deleted file mode 100644 index 9a2e045c2..000000000 --- a/test/test_interpolations_backend.jl +++ /dev/null @@ -1,105 +0,0 @@ -module test_interpolations_backend - -using Test -using TestParticle -using StaticArrays -import TestParticle as TP - -@testset "InterpolationsBackend" begin - backend = InterpolationsBackend() - - @testset "Cartesian 3D scalar" begin - x = range(-10, 10, length = 4) - y = range(-10, 10, length = 6) - z = range(-10, 10, length = 8) - n = Float64[ - i + j + k - for i in eachindex(x), j in eachindex(y), k in eachindex(z) - ] - - # Periodic BC: out-of-domain point should return a finite wrapped value - nfunc = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 1, 2) - @test isfinite(nfunc(SA[20, 0, 0])) - - # Flat BC: out-of-domain point should clamp to boundary value - nfunc_flat = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 1, 3) - @test nfunc_flat(SA[20, 0, 0]) ≈ 12.0 - - # Interior point: exact for linear interpolation - nfunc_int = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 1, 3) - @test nfunc_int(SA[0, 0, 0]) ≈ 10.5 atol = 0.5 - - # Higher orders: just verify they run and return finite values - nfunc2 = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 2, 3) - @test isfinite(nfunc2(SA[0, 0, 0])) - - nfunc3 = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z, 3, 3) - @test isfinite(nfunc3(SA[0, 0, 0])) - end - - @testset "Cartesian 3D vector" begin - x = range(-10, 10, length = 5) - y = range(-10, 10, length = 5) - z = range(-10, 10, length = 5) - B = fill(0.0, 3, length(x), length(y), length(z)) - B[3, :, :, :] .= 1.0 - - Bfunc = TP.build_interpolator(backend, TP.CartesianGrid, B, x, y, z) - @test Bfunc(SA[0.0, 0.0, 0.0]) ≈ SA[0.0, 0.0, 1.0] - end - - @testset "RectilinearGrid" begin - x = [0.0, 1.0, 4.0, 9.0] - y = [0.0, 1.0, 4.0, 9.0] - z = [0.0, 1.0, 4.0, 9.0] - A = [i + j + k for i in x, j in y, k in z] - - Afunc = TP.build_interpolator(backend, TP.RectilinearGrid, A, x, y, z) - @test Afunc(SA[4.0, 4.0, 4.0]) ≈ 12.0 - @test Afunc(SA[1.0, 4.0, 9.0]) ≈ 14.0 - end - - @testset "Spherical 3D" begin - r = range(0.1, 10, length = 11) - θ = range(0, π, length = 11) - ϕ = range(0, 2π, length = 11) - B = fill(0.0, 3, length(r), length(θ), length(ϕ)) - B[1, :, :, :] .= 1.0 - - Bfunc = TP.build_interpolator(backend, TP.StructuredGrid, B, r, θ, ϕ) - @test Bfunc(SA[1, 1, 1]) ≈ [0.57735, 0.57735, 0.57735] atol = 1.0e-5 - - A = ones(length(r), length(θ), length(ϕ)) - Afunc = TP.build_interpolator(backend, TP.StructuredGrid, A, r, θ, ϕ) - @test Afunc(SA[1, 1, 1]) == 1.0 - end - - @testset "prepare kwarg" begin - x = range(-5, 5, length = 10) - y = range(-5, 5, length = 10) - z = range(-5, 5, length = 10) - B = fill(0.0, 3, length(x), length(y), length(z)) - B[3, :, :, :] .= 1.0 - E = fill(0.0, 3, length(x), length(y), length(z)) - - param = prepare(x, y, z, E, B; backend = InterpolationsBackend()) - @test param[4](SA[0.0, 0.0, 0.0]) ≈ SA[0.0, 0.0, 1.0] - end - - @testset "Results match FastInterpolations" begin - x = range(-5, 5, length = 10) - y = range(-5, 5, length = 10) - z = range(-5, 5, length = 10) - n = Float64[ - sin(i) + cos(j) + i * j * k - for i in x, j in y, k in z - ] - - pt = SA[1.0, -1.0, 2.0] - fi_val = TP.build_interpolator(TP.CartesianGrid, n, x, y, z)(pt) - int_val = TP.build_interpolator(backend, TP.CartesianGrid, n, x, y, z)(pt) - @test fi_val ≈ int_val atol = 1.0e-4 - end -end - -end # module test_interpolations_backend From 3eb01c91b6beb5f525a629d35776e829fa6179fd Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Mar 2026 18:18:06 -0500 Subject: [PATCH 26/27] Update to FastInterpolations.jl v0.4.1 --- Project.toml | 2 +- src/TestParticle.jl | 2 +- src/prepare.jl | 2 +- src/utility/fastinterpolation.jl | 48 +++++++++++++++----------------- src/utility/interpolation.jl | 4 +-- 5 files changed, 27 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 2a1451034..da4164895 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ Adapt = "4.4" ChunkSplitters = "3" DiffResults = "1" Distributed = "1" -FastInterpolations = "0.4" +FastInterpolations = "0.4.1" ForwardDiff = "1" KernelAbstractions = "0.9" LinearAlgebra = "1" diff --git a/src/TestParticle.jl b/src/TestParticle.jl index 1d42727c5..446ab712b 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -2,7 +2,7 @@ module TestParticle using LinearAlgebra: norm, ×, ⋅, diag, normalize using FastInterpolations: linear_interp, quadratic_interp, cubic_interp, constant_interp, - Extrap, NoExtrap, PeriodicBC, ZeroCurvBC + Extrap, NoExtrap, PeriodicBC, ZeroCurvBC, FillExtrap using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, ReturnCode, BasicEnsembleAlgorithm, EnsembleThreads, EnsembleSerial, EnsembleDistributed, DEFAULT_SPECIALIZATION, ODEFunction, ODEProblem, remake, diff --git a/src/prepare.jl b/src/prepare.jl index 01ff73475..4dacd209f 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -148,7 +148,7 @@ function prepare( return _prepare(E, B, F, x, y, z; gridtype = StructuredGrid, order, bc, kw...) end -function prepare(x::AbstractVector, E, B, F = ZeroField(); order = 1, bc = 3, dir = 1, kw...) +function prepare(x::AbstractVector, E, B, F = ZeroField(); order = 1, bc = 1, dir = 1, kw...) @assert issorted(x) "Grid vector `x` must be sorted." return _prepare(E, B, F, x; gridtype = CartesianGrid, order, bc, dir, kw...) end diff --git a/src/utility/fastinterpolation.jl b/src/utility/fastinterpolation.jl index 9a21a914f..e7e64fd03 100644 --- a/src/utility/fastinterpolation.jl +++ b/src/utility/fastinterpolation.jl @@ -95,42 +95,38 @@ end Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) -function _get_extrap_mode(bc) +function _get_extrap_mode(bc, T::Type) if bc == 2 return Extrap(:wrap) elseif bc == 3 - return Extrap(:constant) + return Extrap(:clamp) else - # TODO: bc == 1 (NaN outside domain) requires native FastInterpolations.jl support. - # Once available, replace this with the appropriate Extrap mode and remove the - # manual bounds-checking code that was previously in the FieldInterpolator call methods. - return NoExtrap() + if T <: SVector + return FillExtrap(SVector{3, eltype(T)}(NaN, NaN, NaN)) + else + return FillExtrap(T(NaN)) + end end end function _fastinterp(grids, A, order, bc) - T_A = eltype(A) - T_F = T_A <: SVector ? eltype(T_A) : T_A - T_F = T_F <: AbstractFloat ? T_F : Float64 - matched_grids = map(g -> _match_grid_type(g, T_F), grids) - - extrap_mode = _get_extrap_mode(bc) + extrap_mode = _get_extrap_mode(bc, eltype(A)) if order == 1 - return linear_interp(matched_grids, A; extrap = extrap_mode) + return linear_interp(grids, A; extrap = extrap_mode) elseif order == 2 - return quadratic_interp(matched_grids, A; extrap = extrap_mode) + return quadratic_interp(grids, A; extrap = extrap_mode) elseif order == 3 return cubic_interp(grids, A; extrap = extrap_mode) else - return constant_interp(matched_grids, A; extrap = extrap_mode) + return constant_interp(grids, A; extrap = extrap_mode) end end @inline build_interpolator(A, grid1, args...) = build_interpolator(CartesianGrid, A, grid1, args...) """ - build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=3) - build_interpolator(A, grids..., order::Int=1, bc::Int=3) + build_interpolator(gridtype, A, grids..., order::Int=1, bc::Int=1) + build_interpolator(A, grids..., order::Int=1, bc::Int=1) Return a function for interpolating field array `A` on the given grids. @@ -139,7 +135,7 @@ Return a function for interpolating field array `A` on the given grids. - `gridtype`: `CartesianGrid`, `RectilinearGrid` or `StructuredGrid`. Usually determined by the number of grids. - `A`: field array. For vector field, the first dimension should be 3 if it's not an SVector wrapper. - `order::Int=1`: order of interpolation in [1,2,3]. - - `bc::Int=3`: type of boundary conditions, 1 -> NaN (not yet native; requires FastInterpolations support), 2 -> periodic, 3 -> Flat. + - `bc::Int=1`: type of boundary conditions, 1 -> NaN, 2 -> periodic, 3 -> Clamp (flat extrapolation). # Notes The input array `A` may be modified in-place for memory optimization. @@ -147,7 +143,7 @@ The input array `A` may be modified in-place for memory optimization. function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, - order::Int = 1, bc::Int = 3 + order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -157,7 +153,7 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, - order::Int = 1, bc::Int = 3 + order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -170,7 +166,7 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 4}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, - order::Int = 1, bc::Int = 3 + order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -180,7 +176,7 @@ end function build_interpolator( ::Type{<:RectilinearGrid}, A::AbstractArray{T, 3}, gridx::AbstractVector, gridy::AbstractVector, gridz::AbstractVector, - order::Int = 1, bc::Int = 3 + order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -195,7 +191,7 @@ end function build_interpolator( ::Type{<:StructuredGrid}, A::AbstractArray{T, 4}, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} @assert size(A, 1) == 3 && ndims(A) == 4 "Inconsistent 3D force field and grid!" As = reinterpret(reshape, SVector{3, T}, A) @@ -204,7 +200,7 @@ end function build_interpolator( ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, - gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 3 + gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 ) where {T} if eltype(A) <: SVector @assert ndims(A) == 3 "Inconsistent 3D force field and grid! Expected 3D array of SVectors." @@ -225,7 +221,7 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A, - gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 3 + gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 1 ) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." @@ -241,7 +237,7 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A, gridx::AbstractVector, - order::Int = 1, bc::Int = 3; dir = 1 + order::Int = 1, bc::Int = 1; dir = 1 ) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 5bb31ca5d..e9ad141a0 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -239,7 +239,7 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A, - gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 2 + gridx::AbstractVector, gridy::AbstractVector, order::Int = 1, bc::Int = 1 ) if eltype(A) <: SVector @assert ndims(A) == 2 "Inconsistent 2D force field and grid! Expected 2D array of SVectors." @@ -257,7 +257,7 @@ end function build_interpolator( ::Type{<:CartesianGrid}, A, gridx::AbstractVector, - order::Int = 1, bc::Int = 3; dir = 1 + order::Int = 1, bc::Int = 1; dir = 1 ) if eltype(A) <: SVector @assert ndims(A) == 1 "Inconsistent 1D force field and grid! Expected 1D array of SVectors." From 9606b4537adf1cca1fbc73e459551b99e4065c0d Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Mar 2026 18:54:25 -0500 Subject: [PATCH 27/27] Recover tests --- test/test_utility.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_utility.jl b/test/test_utility.jl index e1ea6453f..593d02f20 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -49,23 +49,23 @@ import TestParticle as TP Float32(i + j + k) for i in eachindex(x), j in eachindex(y), k in eachindex(z) ] - # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). - # nfunc11 = TP.build_interpolator(n, x, y, z) - # @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 + nfunc11 = TP.build_interpolator(n, x, y, z) + @test isnan(nfunc11(SA[20, 0, 0])) + @test nfunc11(SA[9, 0, 0]) ≈ 11.85f0 nfunc12 = TP.build_interpolator(n, x, y, z, 1, 2) @test nfunc12(SA[20, 0, 0]) ≈ 10.5f0 nfunc13 = TP.build_interpolator(n, x, y, z, 1, 3) @test nfunc13(SA[20, 0, 0]) ≈ 12.0f0 - # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). - # nfunc21 = TP.build_interpolator(n, x, y, z, 2) - # @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 + nfunc21 = TP.build_interpolator(n, x, y, z, 2) + @test isnan(nfunc21(SA[20, 0, 0])) + @test nfunc21(SA[9, 0, 0]) ≈ 11.85f0 nfunc22 = TP.build_interpolator(n, x, y, z, 2, 2) @test nfunc22(SA[20, 0, 0]) ≈ 10.5f0 nfunc23 = TP.build_interpolator(n, x, y, z, 2, 3) @test nfunc23(SA[20, 0, 0]) ≈ 12.0f0 - # TODO: re-enable when FastInterpolations.jl natively supports bc=1 (NaN outside domain). - # nfunc31 = TP.build_interpolator(n, x, y, z, 3) - # @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 + nfunc31 = TP.build_interpolator(n, x, y, z, 3) + @test isnan(nfunc31(SA[20, 0, 0])) + @test nfunc31(SA[9, 0, 0]) ≈ 11.85f0 nfunc32 = TP.build_interpolator(n, x, y, z, 3, 2) @test nfunc32(SA[20, 0, 0]) ≈ 10.5f0 nfunc33 = TP.build_interpolator(n, x, y, z, 3, 3) @@ -85,7 +85,7 @@ import TestParticle as TP A = ones(length(r), length(θ), length(ϕ)) Afunc = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ) @test Afunc(SA[1, 1, 1]) == 1.0 - @test Afunc(SA[0, 0, 0]) == 1.0 + @test isnan(Afunc(SA[0, 0, 0])) # outside domain -> NaN with bc=1 # High order spherical interpolation Bfunc2 = TP.build_interpolator(TP.StructuredGrid, B, r, θ, ϕ, 2) @@ -112,7 +112,7 @@ import TestParticle as TP A = ones(length(r), length(θ), length(ϕ)) Afunc = TP.build_interpolator(TP.StructuredGrid, A, r, θ, ϕ) @test Afunc(SA[1, 1, 1]) == 1.0 - @test Afunc(SA[0, 0, 0]) == 1.0 + @test isnan(Afunc(SA[0, 0, 0])) # outside domain -> NaN with bc=1 end end