Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module LinearSolveForwardDiffExt

using LinearSolve
using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
DefaultAlgorithmChoice, defaultalg
DefaultAlgorithmChoice, defaultalg, reinit!
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Partials
Expand Down Expand Up @@ -342,6 +342,38 @@ function setu!(dc::DualLinearCache, u)
partial_vals!(getfield(dc, :partials_u), u) # Update in-place
end

function SciMLBase.reinit!(cache::DualLinearCache;
A = nothing,
b = nothing,
u = nothing,
p = nothing,
reuse_precs = false)
if !isnothing(A)
setA!(cache, A)
end

if !isnothing(b)
setb!(cache, b)
end

if !isnothing(u)
setu!(cache, u)
end

if !isnothing(p)
cache.linear_cache.p=p
end

isfresh = !isnothing(A)
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
isfresh |= cache.isfresh
precsisfresh |= cache.linear_cache.precsisfresh
cache.linear_cache.isfresh = true
cache.linear_cache.precsisfresh = precsisfresh

nothing
end

function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
# If the property is A or b, also update it in the LinearCache
if sym === :A
Expand Down Expand Up @@ -390,7 +422,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place
nodual_value(x) = x
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
function nodual_value(x::AbstractArray{<:Dual})
nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
end
nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place

function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T}
Expand Down
8 changes: 7 additions & 1 deletion test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ backslash_x_p = A \ b

@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)

A[1, 1]+=2
cache = overload_x_p.cache
reinit!(cache; A = sparse(A))
overload_x_p = solve!(cache, UMFPACKFactorization())
backslash_x_p = A \ b
@test ≈ (overload_x_p, backslash_x_p, rtol = 1e-9)

# Test that GenericLU doesn't create a DualLinearCache
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
Expand Down Expand Up @@ -234,4 +240,4 @@ grad = ForwardDiff.gradient(component_linsolve, p_test)
@test grad isa Vector
@test length(grad) == 2
@test !any(isnan, grad)
@test !any(isinf, grad)
@test !any(isinf, grad)
Loading