@@ -163,8 +163,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
163163
164164 linsolve = get_linear_solver (alg. descent)
165165 initialization_cache = __internal_init (prob, alg. initialization, alg, f, fu, u, p;
166- linsolve,
167- maxiters, internalnorm)
166+ linsolve, maxiters, internalnorm)
168167
169168 abstol, reltol, termination_cache = init_termination_cache (abstol, reltol, fu, u,
170169 termination_condition)
@@ -222,9 +221,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
222221 new_jacobian = true
223222 @static_timeit cache. timer " jacobian init/reinit" begin
224223 if get_nsteps (cache) == 0 # First Step is special ignore kwargs
225- J_init = __internal_solve! (cache. initialization_cache,
226- cache. fu,
227- cache. u,
224+ J_init = __internal_solve! (cache. initialization_cache, cache. fu, cache. u,
228225 Val (false ))
229226 if INV
230227 if jacobian_initialized_preinverted (cache. initialization_cache. alg)
@@ -283,52 +280,65 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
283280 @static_timeit cache. timer " descent" begin
284281 if cache. trustregion_cache != = nothing &&
285282 hasfield (typeof (cache. trustregion_cache), :trust_region )
286- δu, descent_success, descent_intermediates = __internal_solve! (cache. descent_cache,
287- J, cache. fu, cache. u; new_jacobian,
288- trust_region = cache. trustregion_cache. trust_region)
283+ descent_result = __internal_solve! (cache. descent_cache, J, cache. fu, cache. u;
284+ new_jacobian, trust_region = cache. trustregion_cache. trust_region)
289285 else
290- δu, descent_success, descent_intermediates = __internal_solve! (cache. descent_cache,
291- J, cache . fu, cache . u; new_jacobian)
286+ descent_result = __internal_solve! (cache. descent_cache, J, cache . fu, cache . u;
287+ new_jacobian)
292288 end
293289 end
294290
295- if descent_success
296- if GB === :LineSearch
297- @static_timeit cache. timer " linesearch" begin
298- needs_reset, α = __internal_solve! (cache. linesearch_cache, cache. u, δu)
299- end
300- if needs_reset && cache. steps_since_last_reset > 5 # Reset after a burn-in period
301- cache. force_reinit = true
302- else
303- @static_timeit cache. timer " step" begin
304- @bb axpy! (α, δu, cache. u)
305- evaluate_f! (cache, cache. u, cache. p)
306- end
307- end
308- elseif GB === :TrustRegion
309- @static_timeit cache. timer " trustregion" begin
310- tr_accepted, u_new, fu_new = __internal_solve! (cache. trustregion_cache, J,
311- cache. fu, cache. u, δu, descent_intermediates)
312- if tr_accepted
313- @bb copyto! (cache. u, u_new)
314- @bb copyto! (cache. fu, fu_new)
315- end
316- if hasfield (typeof (cache. trustregion_cache), :shrink_counter ) &&
317- cache. trustregion_cache. shrink_counter > cache. max_shrink_times
318- cache. retcode = ReturnCode. ShrinkThresholdExceeded
319- cache. force_stop = true
320- end
321- end
322- α = true
323- elseif GB === :None
291+ if descent_result. success
292+ if GB === :None
324293 @static_timeit cache. timer " step" begin
325- @bb axpy! (1 , δu, cache. u)
294+ if descent_result. u != = missing
295+ @bb copyto! (cache. u, descent_result. u)
296+ elseif descent_result. δu != = missing
297+ @bb axpy! (1 , descent_result. δu, cache. u)
298+ else
299+ error (" This shouldn't occur. `$(cache. alg. descent) ` is incorrectly \
300+ specified." )
301+ end
326302 evaluate_f! (cache, cache. u, cache. p)
327303 end
328304 α = true
329305 else
330- error (" Unknown Globalization Strategy: $(GB) . Allowed values are (:LineSearch, \
331- :TrustRegion, :None)" )
306+ δu = descent_result. δu
307+ @assert δu!= = missing " Descent Supporting LineSearch or TrustRegion must return a `δu`."
308+
309+ if GB === :LineSearch
310+ @static_timeit cache. timer " linesearch" begin
311+ needs_reset, α = __internal_solve! (cache. linesearch_cache, cache. u, δu)
312+ end
313+ if needs_reset && cache. steps_since_last_reset > 5 # Reset after a burn-in period
314+ cache. force_reinit = true
315+ else
316+ @static_timeit cache. timer " step" begin
317+ @bb axpy! (α, δu, cache. u)
318+ evaluate_f! (cache, cache. u, cache. p)
319+ end
320+ end
321+ elseif GB === :TrustRegion
322+ @static_timeit cache. timer " trustregion" begin
323+ tr_accepted, u_new, fu_new = __internal_solve! (cache. trustregion_cache,
324+ J, cache. fu, cache. u, δu, descent_result. extras)
325+ if tr_accepted
326+ @bb copyto! (cache. u, u_new)
327+ @bb copyto! (cache. fu, fu_new)
328+ α = true
329+ else
330+ α = false
331+ end
332+ if hasfield (typeof (cache. trustregion_cache), :shrink_counter ) &&
333+ cache. trustregion_cache. shrink_counter > cache. max_shrink_times
334+ cache. retcode = ReturnCode. ShrinkThresholdExceeded
335+ cache. force_stop = true
336+ end
337+ end
338+ else
339+ error (" Unknown Globalization Strategy: $(GB) . Allowed values are \
340+ (:LineSearch, :TrustRegion, :None)" )
341+ end
332342 end
333343 check_and_update! (cache, cache. fu, cache. u, cache. u_cache)
334344 else
0 commit comments