@@ -21,23 +21,24 @@ struct __PotraPtak3 <: AbstractMultiStepScheme end
2121const PotraPtak3 = __PotraPtak3 ()
2222
2323alg_steps (:: __PotraPtak3 ) = 2
24+ nintermediates (:: __PotraPtak3 ) = 1
2425
2526@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
26- vjp_autodiff = nothing
27+ jvp_autodiff = nothing
2728end
2829const SinghSharma4 = __SinghSharma4 ()
2930
3031alg_steps (:: __SinghSharma4 ) = 3
3132
3233@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
33- vjp_autodiff = nothing
34+ jvp_autodiff = nothing
3435end
3536const SinghSharma5 = __SinghSharma5 ()
3637
3738alg_steps (:: __SinghSharma5 ) = 3
3839
3940@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
40- vjp_autodiff = nothing
41+ jvp_autodiff = nothing
4142end
4243const SinghSharma7 = __SinghSharma7 ()
4344
6061
6162Base. show (io:: IO , alg:: GenericMultiStepDescent ) = print (io, " $(alg. scheme) ()" )
6263
63- supports_line_search (:: GenericMultiStepDescent ) = false
64+ supports_line_search (:: GenericMultiStepDescent ) = true
6465supports_trust_region (:: GenericMultiStepDescent ) = false
6566
66- @concrete mutable struct GenericMultiStepDescentCache{S, INV } <: AbstractDescentCache
67+ @concrete mutable struct GenericMultiStepDescentCache{S} <: AbstractDescentCache
6768 f
6869 p
6970 δu
7071 δus
71- extras
72+ u
73+ us
74+ fu
75+ fus
76+ internal_cache
77+ internal_caches
7278 scheme:: S
73- lincache
7479 timer
7580 nf:: Int
7681end
7782
78- @internal_caches GenericMultiStepDescentCache :lincache
83+ # FIXME : @internal_caches needs to be updated to support tuples and namedtuples
84+ # @internal_caches GenericMultiStepDescentCache :internal_caches
7985
8086function __reinit_internal! (cache:: GenericMultiStepDescentCache , args... ; p = cache. p,
8187 kwargs... )
8288 cache. nf = 0
8389 cache. p = p
90+ reset_timer! (cache. timer)
8491end
8592
86- function __δu_caches (scheme:: MSS.__PotraPtak3 , fu, u, :: Val{N} ) where {N}
87- caches = ntuple (N) do i
88- @bb δu = similar (u)
89- @bb y = similar (u)
90- @bb fy = similar (fu)
91- @bb δy = similar (u)
92- @bb u_new = similar (u)
93- (δu, δy, fy, y, u_new)
93+ function __internal_multistep_caches (
94+ scheme:: MSS.__PotraPtak3 , alg:: GenericMultiStepDescent ,
95+ prob, args... ; shared:: Val{N} = Val (1 ), kwargs... ) where {N}
96+ internal_descent = NewtonDescent (; alg. linsolve, alg. precs)
97+ internal_cache = __internal_init (
98+ prob, internal_descent, args... ; kwargs... , shared = Val (2 ))
99+ internal_caches = N ≤ 1 ? nothing :
100+ map (2 : N) do i
101+ __internal_init (prob, internal_descent, args... ; kwargs... , shared = Val (2 ))
94102 end
95- return first (caches), (N ≤ 1 ? nothing : caches[ 2 : end ])
103+ return internal_cache, internal_caches
96104end
97105
98- function __internal_init (prob:: NonlinearProblem , alg:: GenericMultiStepDescent , J, fu, u;
99- shared:: Val{N} = Val (1 ), pre_inverted:: Val{INV} = False, linsolve_kwargs = (;),
106+ function __internal_init (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
107+ alg:: GenericMultiStepDescent , J, fu, u; shared:: Val{N} = Val (1 ),
108+ pre_inverted:: Val{INV} = False, linsolve_kwargs = (;),
100109 abstol = nothing , reltol = nothing , timer = get_timer_output (),
101110 kwargs... ) where {INV, N}
102- δu, δus = __δu_caches (alg. scheme, fu, u, shared)
103- INV && return GenericMultiStepDescentCache {true} (prob. f, prob. p, δu, δus,
104- alg. scheme, nothing , timer, 0 )
105- lincache = LinearSolverCache (alg, alg. linsolve, J, _vec (fu), _vec (u); abstol, reltol,
106- linsolve_kwargs... )
107- return GenericMultiStepDescentCache {false} (prob. f, prob. p, δu, δus, alg. scheme,
108- lincache, timer, 0 )
109- end
110-
111- function __internal_init (prob:: NonlinearLeastSquaresProblem , alg:: GenericMultiStepDescent ,
112- J, fu, u; kwargs... )
113- error (" Multi-Step Descent Algorithms for NLLS are not implemented yet." )
111+ @bb δu = similar (u)
112+ δus = N ≤ 1 ? nothing : map (2 : N) do i
113+ @bb δu_ = similar (u)
114+ end
115+ fu_cache = ntuple (MSS. nintermediates (alg. scheme)) do i
116+ @bb xx = similar (fu)
117+ end
118+ fus_cache = N ≤ 1 ? nothing : map (2 : N) do i
119+ ntuple (MSS. nintermediates (alg. scheme)) do j
120+ @bb xx = similar (fu)
121+ end
122+ end
123+ u_cache = ntuple (MSS. nintermediates (alg. scheme)) do i
124+ @bb xx = similar (u)
125+ end
126+ us_cache = N ≤ 1 ? nothing : map (2 : N) do i
127+ ntuple (MSS. nintermediates (alg. scheme)) do j
128+ @bb xx = similar (u)
129+ end
130+ end
131+ internal_cache, internal_caches = __internal_multistep_caches (
132+ alg. scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
133+ abstol, reltol, timer, kwargs... )
134+ return GenericMultiStepDescentCache (
135+ prob. f, prob. p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
136+ internal_cache, internal_caches, alg. scheme, timer, 0 )
114137end
115138
116139function __internal_solve! (cache:: GenericMultiStepDescentCache{MSS.__PotraPtak3, INV} , J,
117140 fu, u, idx:: Val = Val (1 ); skip_solve:: Bool = false , new_jacobian:: Bool = true ,
118141 kwargs... ) where {INV}
119- (u_new, δy, fy, y, δu) = get_du (cache, idx)
120- skip_solve && return DescentResult (; u = u_new)
121-
122- @static_timeit cache. timer " linear solve" begin
123- @static_timeit cache. timer " solve and step 1" begin
124- if INV
125- J != = nothing && @bb (δu= J × _vec (fu))
126- else
127- δu = cache. lincache (; A = J, b = _vec (fu), kwargs... , linu = _vec (δu),
128- du = _vec (δu),
129- reuse_A_if_factorization = ! new_jacobian || (idx != = Val (1 )))
130- δu = _restructure (u, δu)
131- end
132- @bb @. y = u - δu
133- end
142+ δu = get_du (cache, idx)
143+ skip_solve && return DescentResult (; δu)
144+
145+ (y,) = get_internal_cache (cache, Val (:u ), idx)
146+ (fy,) = get_internal_cache (cache, Val (:fu ), idx)
147+ internal_cache = get_internal_cache (cache, Val (:internal_cache ), idx)
134148
149+ @static_timeit cache. timer " descent step" begin
150+ result_1 = __internal_solve! (
151+ internal_cache, J, fu, u, Val (1 ); new_jacobian, kwargs... )
152+ δx = result_1. δu
153+
154+ @bb @. y = u + δx
135155 fy = evaluate_f!! (cache. f, fy, y, cache. p)
136156 cache. nf += 1
137157
138- @static_timeit cache. timer " solve and step 2" begin
139- if INV
140- J != = nothing && @bb (δy= J × _vec (fy))
141- else
142- δy = cache. lincache (; A = J, b = _vec (fy), kwargs... , linu = _vec (δy),
143- du = _vec (δy), reuse_A_if_factorization = true )
144- δy = _restructure (u, δy)
145- end
146- @bb @. u_new = y - δy
147- end
158+ result_2 = __internal_solve! (
159+ internal_cache, J, fy, y, Val (2 ); kwargs... )
160+ δy = result_2. δu
161+
162+ @bb @. δu = δx + δy
148163 end
149164
150- set_du! (cache, (u_new, δy, fy, y, δu), idx)
151- return DescentResult (; u = u_new)
165+ set_du! (cache, δu, idx)
166+ set_internal_cache! (cache, (y,), Val (:u ), idx)
167+ set_internal_cache! (cache, (fy,), Val (:fu ), idx)
168+ set_internal_cache! (cache, internal_cache, Val (:internal_cache ), idx)
169+ return DescentResult (; δu)
152170end
0 commit comments