@@ -21,23 +21,24 @@ struct __PotraPtak3 <: AbstractMultiStepScheme end
2121const  PotraPtak3 =  __PotraPtak3 ()
2222
2323alg_steps (:: __PotraPtak3 ) =  2 
24+ nintermediates (:: __PotraPtak3 ) =  1 
2425
2526@kwdef  @concrete  struct  __SinghSharma4 <:  AbstractMultiStepScheme 
26-     vjp_autodiff  =  nothing 
27+     jvp_autodiff  =  nothing 
2728end 
2829const  SinghSharma4 =  __SinghSharma4 ()
2930
3031alg_steps (:: __SinghSharma4 ) =  3 
3132
3233@kwdef  @concrete  struct  __SinghSharma5 <:  AbstractMultiStepScheme 
33-     vjp_autodiff  =  nothing 
34+     jvp_autodiff  =  nothing 
3435end 
3536const  SinghSharma5 =  __SinghSharma5 ()
3637
3738alg_steps (:: __SinghSharma5 ) =  3 
3839
3940@kwdef  @concrete  struct  __SinghSharma7 <:  AbstractMultiStepScheme 
40-     vjp_autodiff  =  nothing 
41+     jvp_autodiff  =  nothing 
4142end 
4243const  SinghSharma7 =  __SinghSharma7 ()
4344
6061
6162Base. show (io:: IO , alg:: GenericMultiStepDescent ) =  print (io, " $(alg. scheme) ()"  )
6263
63- supports_line_search (:: GenericMultiStepDescent ) =  false 
64+ supports_line_search (:: GenericMultiStepDescent ) =  true 
6465supports_trust_region (:: GenericMultiStepDescent ) =  false 
6566
66- @concrete  mutable struct  GenericMultiStepDescentCache{S, INV } <:  AbstractDescentCache 
67+ @concrete  mutable struct  GenericMultiStepDescentCache{S} <:  AbstractDescentCache 
6768    f
6869    p
6970    δu
7071    δus
71-     extras
72+     u
73+     us
74+     fu
75+     fus
76+     internal_cache
77+     internal_caches
7278    scheme:: S 
73-     lincache
7479    timer
7580    nf:: Int 
7681end 
7782
78- @internal_caches  GenericMultiStepDescentCache :lincache 
83+ #  FIXME : @internal_caches needs to be updated to support tuples and namedtuples
84+ #  @internal_caches GenericMultiStepDescentCache :internal_caches
7985
8086function  __reinit_internal! (cache:: GenericMultiStepDescentCache , args... ; p =  cache. p,
8187        kwargs... )
8288    cache. nf =  0 
8389    cache. p =  p
90+     reset_timer! (cache. timer)
8491end 
8592
86- function  __δu_caches (scheme:: MSS.__PotraPtak3 , fu, u, :: Val{N} ) where  {N}
87-     caches =  ntuple (N) do  i
88-         @bb  δu =  similar (u)
89-         @bb  y =  similar (u)
90-         @bb  fy =  similar (fu)
91-         @bb  δy =  similar (u)
92-         @bb  u_new =  similar (u)
93-         (δu, δy, fy, y, u_new)
93+ function  __internal_multistep_caches (
94+         scheme:: MSS.__PotraPtak3 , alg:: GenericMultiStepDescent ,
95+         prob, args... ; shared:: Val{N}  =  Val (1 ), kwargs... ) where  {N}
96+     internal_descent =  NewtonDescent (; alg. linsolve, alg. precs)
97+     internal_cache =  __internal_init (
98+         prob, internal_descent, args... ; kwargs... , shared =  Val (2 ))
99+     internal_caches =  N ≤  1  ?  nothing  : 
100+                       map (2 : N) do  i
101+         __internal_init (prob, internal_descent, args... ; kwargs... , shared =  Val (2 ))
94102    end 
95-     return  first (caches), (N  ≤   1   ?   nothing   :  caches[ 2 : end ]) 
103+     return  internal_cache, internal_caches 
96104end 
97105
98- function  __internal_init (prob:: NonlinearProblem , alg:: GenericMultiStepDescent , J, fu, u;
99-         shared:: Val{N}  =  Val (1 ), pre_inverted:: Val{INV}  =  False, linsolve_kwargs =  (;),
106+ function  __internal_init (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
107+         alg:: GenericMultiStepDescent , J, fu, u; shared:: Val{N}  =  Val (1 ),
108+         pre_inverted:: Val{INV}  =  False, linsolve_kwargs =  (;),
100109        abstol =  nothing , reltol =  nothing , timer =  get_timer_output (),
101110        kwargs... ) where  {INV, N}
102-     δu, δus =  __δu_caches (alg. scheme, fu, u, shared)
103-     INV &&  return  GenericMultiStepDescentCache {true} (prob. f, prob. p, δu, δus,
104-         alg. scheme, nothing , timer, 0 )
105-     lincache =  LinearSolverCache (alg, alg. linsolve, J, _vec (fu), _vec (u); abstol, reltol,
106-         linsolve_kwargs... )
107-     return  GenericMultiStepDescentCache {false} (prob. f, prob. p, δu, δus, alg. scheme,
108-         lincache, timer, 0 )
109- end 
110- 
111- function  __internal_init (prob:: NonlinearLeastSquaresProblem , alg:: GenericMultiStepDescent ,
112-         J, fu, u; kwargs... )
113-     error (" Multi-Step Descent Algorithms for NLLS are not implemented yet."  )
111+     @bb  δu =  similar (u)
112+     δus =  N ≤  1  ?  nothing  :  map (2 : N) do  i
113+         @bb  δu_ =  similar (u)
114+     end 
115+     fu_cache =  ntuple (MSS. nintermediates (alg. scheme)) do  i
116+         @bb  xx =  similar (fu)
117+     end 
118+     fus_cache =  N ≤  1  ?  nothing  :  map (2 : N) do  i
119+         ntuple (MSS. nintermediates (alg. scheme)) do  j
120+             @bb  xx =  similar (fu)
121+         end 
122+     end 
123+     u_cache =  ntuple (MSS. nintermediates (alg. scheme)) do  i
124+         @bb  xx =  similar (u)
125+     end 
126+     us_cache =  N ≤  1  ?  nothing  :  map (2 : N) do  i
127+         ntuple (MSS. nintermediates (alg. scheme)) do  j
128+             @bb  xx =  similar (u)
129+         end 
130+     end 
131+     internal_cache, internal_caches =  __internal_multistep_caches (
132+         alg. scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
133+         abstol, reltol, timer, kwargs... )
134+     return  GenericMultiStepDescentCache (
135+         prob. f, prob. p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
136+         internal_cache, internal_caches, alg. scheme, timer, 0 )
114137end 
115138
116139function  __internal_solve! (cache:: GenericMultiStepDescentCache{MSS.__PotraPtak3, INV} , J,
117140        fu, u, idx:: Val  =  Val (1 ); skip_solve:: Bool  =  false , new_jacobian:: Bool  =  true ,
118141        kwargs... ) where  {INV}
119-     (u_new, δy, fy, y, δu) =  get_du (cache, idx)
120-     skip_solve &&  return  DescentResult (; u =  u_new)
121- 
122-     @static_timeit  cache. timer " linear solve"   begin 
123-         @static_timeit  cache. timer " solve and step 1"   begin 
124-             if  INV
125-                 J != =  nothing  &&  @bb (δu= J ×  _vec (fu))
126-             else 
127-                 δu =  cache. lincache (; A =  J, b =  _vec (fu), kwargs... , linu =  _vec (δu),
128-                     du =  _vec (δu),
129-                     reuse_A_if_factorization =  ! new_jacobian ||  (idx != =  Val (1 )))
130-                 δu =  _restructure (u, δu)
131-             end 
132-             @bb  @.  y =  u -  δu
133-         end 
142+     δu =  get_du (cache, idx)
143+     skip_solve &&  return  DescentResult (; δu)
144+ 
145+     (y,) =  get_internal_cache (cache, Val (:u ), idx)
146+     (fy,) =  get_internal_cache (cache, Val (:fu ), idx)
147+     internal_cache =  get_internal_cache (cache, Val (:internal_cache ), idx)
134148
149+     @static_timeit  cache. timer " descent step"   begin 
150+         result_1 =  __internal_solve! (
151+             internal_cache, J, fu, u, Val (1 ); new_jacobian, kwargs... )
152+         δx =  result_1. δu
153+ 
154+         @bb  @.  y =  u +  δx
135155        fy =  evaluate_f!! (cache. f, fy, y, cache. p)
136156        cache. nf +=  1 
137157
138-         @static_timeit  cache. timer " solve and step 2"   begin 
139-             if  INV
140-                 J != =  nothing  &&  @bb (δy= J ×  _vec (fy))
141-             else 
142-                 δy =  cache. lincache (; A =  J, b =  _vec (fy), kwargs... , linu =  _vec (δy),
143-                     du =  _vec (δy), reuse_A_if_factorization =  true )
144-                 δy =  _restructure (u, δy)
145-             end 
146-             @bb  @.  u_new =  y -  δy
147-         end 
158+         result_2 =  __internal_solve! (
159+             internal_cache, J, fy, y, Val (2 ); kwargs... )
160+         δy =  result_2. δu
161+ 
162+         @bb  @.  δu =  δx +  δy
148163    end 
149164
150-     set_du! (cache, (u_new, δy, fy, y, δu), idx)
151-     return  DescentResult (; u =  u_new)
165+     set_du! (cache, δu, idx)
166+     set_internal_cache! (cache, (y,), Val (:u ), idx)
167+     set_internal_cache! (cache, (fy,), Val (:fu ), idx)
168+     set_internal_cache! (cache, internal_cache, Val (:internal_cache ), idx)
169+     return  DescentResult (; δu)
152170end 
0 commit comments