From 81b4b4c1c9572c08d1b83a1ba597602aed308a76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Fuhrmann?= Date: Fri, 28 Nov 2025 20:37:32 -0800 Subject: [PATCH] Attempt the implementation of reinit! for dual cache --- ext/LinearSolveForwardDiffExt.jl | 38 ++++++++++++++++++++++++++++++-- test/forwarddiff_overloads.jl | 8 ++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index cacba0112..467c2b852 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -2,7 +2,7 @@ module LinearSolveForwardDiffExt using LinearSolve using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver, - DefaultAlgorithmChoice, defaultalg + DefaultAlgorithmChoice, defaultalg, reinit! using LinearAlgebra using ForwardDiff using ForwardDiff: Dual, Partials @@ -342,6 +342,38 @@ function setu!(dc::DualLinearCache, u) partial_vals!(getfield(dc, :partials_u), u) # Update in-place end +function SciMLBase.reinit!(cache::DualLinearCache; + A = nothing, + b = nothing, + u = nothing, + p = nothing, + reuse_precs = false) + if !isnothing(A) + setA!(cache, A) + end + + if !isnothing(b) + setb!(cache, b) + end + + if !isnothing(u) + setu!(cache, u) + end + + if !isnothing(p) + cache.linear_cache.p=p + end + + isfresh = !isnothing(A) + precsisfresh = !reuse_precs && (isfresh || !isnothing(p)) + isfresh |= cache.isfresh + precsisfresh |= cache.linear_cache.precsisfresh + cache.linear_cache.isfresh = true + cache.linear_cache.precsisfresh = precsisfresh + + nothing +end + function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A @@ -390,7 +422,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place nodual_value(x) = x nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x) nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact -nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x) +function nodual_value(x::AbstractArray{<:Dual}) + nodual_value!(similar(x, typeof(nodual_value(first(x)))), x) +end nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T} diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 46116e7a3..d70e86fb6 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -188,6 +188,12 @@ backslash_x_p = A \ b @test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) +A[1, 1]+=2 +cache = overload_x_p.cache +reinit!(cache; A = sparse(A)) +overload_x_p = solve!(cache, UMFPACKFactorization()) +backslash_x_p = A \ b +@test ≈ (overload_x_p, backslash_x_p, rtol = 1e-9) # Test that GenericLU doesn't create a DualLinearCache A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) @@ -234,4 +240,4 @@ grad = ForwardDiff.gradient(component_linsolve, p_test) @test grad isa Vector @test length(grad) == 2 @test !any(isnan, grad) -@test !any(isinf, grad) \ No newline at end of file +@test !any(isinf, grad)