Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 2a2e77d

Browse files
add HesVecGrad
1 parent 9cd3d85 commit 2a2e77d

File tree

3 files changed

+98
-1
lines changed

3 files changed

+98
-1
lines changed

src/SparseDiffTools.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ export contract_color,
1616
num_hesvec,num_hesvec!,
1717
numauto_hesvec,numauto_hesvec!,
1818
autonum_hesvec,autonum_hesvec!,
19-
JacVec,HesVec
19+
num_hesvecgrad,num_hesvecgrad!,
20+
auto_hesvecgrad,auto_hesvecgrad!,
21+
JacVec,HesVec,HesVecGrad
2022

2123
include("coloring/high_level.jl")
2224
include("coloring/contraction_coloring.jl")

src/differentiation/jaches_products.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,42 @@ function autonum_hesvec(f,x,v)
108108
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
109109
end
110110

111+
function num_hesvecgrad!(du,g,x,v,
112+
cache2 = similar(v),
113+
cache3 = similar(v))
114+
T = eltype(x)
115+
# Should it be min? max? mean?
116+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
117+
@. x += ϵ*v
118+
g(cache2,x)
119+
@. x -= 2ϵ*v
120+
g(cache3,x)
121+
@. du = (cache2 - cache3)/(2ϵ)
122+
end
123+
124+
function num_hesvecgrad(g,x,v)
125+
T = eltype(x)
126+
# Should it be min? max? mean?
127+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
128+
x += ϵ*v
129+
gxp = g(x)
130+
x -= 2ϵ*v
131+
gxm = g(x)
132+
(gxp - gxm)/(2ϵ)
133+
end
134+
135+
function auto_hesvecgrad!(du,g,x,v,
136+
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
137+
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v))
138+
cache2 .= Dual{DeivVecTag}.(x, v)
139+
g(cache3,cache2)
140+
du .= partials.(cache3, 1)
141+
end
142+
143+
function auto_hesvecgrad(g,x,v)
144+
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
145+
end
146+
111147
### Operator Forms
112148

113149
mutable struct JacVec{F,T1,T2,uType}
@@ -174,3 +210,34 @@ function LinearAlgebra.mul!(du::AbstractVector,L::HesVec,v::AbstractVector)
174210
num_hesvec!(du,L.f,L.u,v,L.cache1,L.cache2,L.cache3)
175211
end
176212
end
213+
214+
mutable struct HesVecGrad{G,T1,T2,uType}
215+
g::G
216+
cache1::T1
217+
cache2::T2
218+
u::uType
219+
autodiff::Bool
220+
end
221+
222+
function HesVecGrad(g,u::AbstractArray;autodiff=true)
223+
if autodiff
224+
cache1 = ForwardDiff.Dual{DeivVecTag}.(u, u)
225+
cache2 = ForwardDiff.Dual{DeivVecTag}.(u, u)
226+
else
227+
cache1 = similar(u)
228+
cache2 = similar(u)
229+
end
230+
HesVecGrad(g,cache1,cache2,u,autodiff)
231+
end
232+
233+
Base.size(L::HesVecGrad) = (length(L.cache2),length(L.cache2))
234+
Base.size(L::HesVecGrad,i::Int) = length(L.cache2)
235+
Base.:*(L::HesVecGrad,x::AbstractVector) = L.autodiff ? auto_hesvecgrad(L.g,L.u,x) : num_hesvecgrad(L.g,L.u,x)
236+
237+
function LinearAlgebra.mul!(du::AbstractVector,L::HesVecGrad,v::AbstractVector)
238+
if L.autodiff
239+
auto_hesvecgrad!(du,L.g,L.u,v,L.cache1,L.cache2)
240+
else
241+
num_hesvecgrad!(du,L.g,L.u,v,L.cache1,L.cache2)
242+
end
243+
end

test/test_jaches_products.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ f(u) = sum(u.^2)
3131
@test autonum_hesvec!(du, f, x, v, similar(v), cache1, cache2) ForwardDiff.hessian(f,x)*v rtol=1e-2
3232
@test autonum_hesvec(f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-8
3333

34+
function g(x)
35+
DiffEqDiffTools.finite_difference_gradient(f,x)
36+
end
37+
function g(dx,x)
38+
DiffEqDiffTools.finite_difference_gradient!(dx,f,x)
39+
end
40+
@test num_hesvecgrad!(du, g, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
41+
@test num_hesvecgrad!(du, g, x, v, similar(v), similar(v)) ForwardDiff.hessian(f,x)*v rtol=1e-2
42+
@test num_hesvecgrad(g, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
43+
44+
@test auto_hesvecgrad!(du, g, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
45+
@test auto_hesvecgrad!(du, g, x, v, cache1, cache2) ForwardDiff.hessian(f,x)*v rtol=1e-2
46+
@test auto_hesvecgrad(g, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
47+
3448
f(du,u) = mul!(du,A,u)
3549
f(u) = A*u
3650
L = JacVec(f,x)
@@ -72,3 +86,17 @@ L.u .= v
7286
### Integration test with IterativeSolvers
7387
out = similar(v)
7488
gmres!(out, L, v)
89+
90+
L = HesVecGrad(g,x,autodiff=false)
91+
@test L*x num_hesvec(f, x, x)
92+
@test L*v num_hesvec(f, x, v)
93+
@test mul!(du,L,v) num_hesvec(f, x, v) rtol=1e-2
94+
L.u .= v
95+
@test mul!(du,L,v) num_hesvec(f, v, v) rtol=1e-2
96+
97+
L = HesVecGrad(g,x)
98+
@test L*x numauto_hesvec(f, x, x)
99+
@test L*v numauto_hesvec(f, x, v)
100+
@test mul!(du,L,v) numauto_hesvec(f, x, v) rtol=1e-8
101+
L.u .= v
102+
@test mul!(du,L,v) numauto_hesvec(f, v, v) rtol=1e-8

0 commit comments

Comments
 (0)