@@ -90,18 +90,33 @@ from NLopt for an example. The common local optimizer arguments are:
9090 - `local_reltol`: relative tolerance in changes of the objective value
9191 - `local_options`: `NamedTuple` of keyword arguments for local optimizer
9292"""
93- function solve (prob:: SciMLBase.OptimizationProblem , alg, args... ;
94- kwargs... ):: SciMLBase.AbstractOptimizationSolution
95- if SciMLBase. has_init (alg)
96- solve! (init (prob, alg, args... ; kwargs... ))
93+ function solve (prob:: SciMLBase.OptimizationProblem , args... ; sensealg = nothing ,
94+ u0 = nothing , p = nothing , wrap = Val (true ), kwargs... ):: SciMLBase.AbstractOptimizationSolution
95+ if sensealg === nothing && haskey (prob. kwargs, :sensealg )
96+ sensealg = prob. kwargs[:sensealg ]
97+ end
98+
99+ u0 = u0 != = nothing ? u0 : prob. u0
100+ p = p != = nothing ? p : prob. p
101+
102+ if wrap isa Val{true }
103+ wrap_sol (solve_up (prob,
104+ sensealg,
105+ u0,
106+ p,
107+ args... ;
108+ originator = SciMLBase. ChainRulesOriginator (),
109+ kwargs... ))
97110 else
98- if prob. u0 != = nothing && ! isconcretetype (eltype (prob. u0))
99- throw (SciMLBase. NonConcreteEltypeError (eltype (prob. u0)))
100- end
101- _check_opt_alg (prob, alg; kwargs... )
102- __solve (prob, alg, args... ; kwargs... )
111+ solve_up (prob,
112+ sensealg,
113+ u0,
114+ p,
115+ args... ;
116+ originator = SciMLBase. ChainRulesOrginator (),
117+ kwargs... )
103118 end
104- end
119+ end
105120
106121function solve (
107122 prob:: SciMLBase.EnsembleProblem{T} , args... ; kwargs... ) where {T < :
@@ -216,3 +231,101 @@ end
216231function __solve (prob:: SciMLBase.OptimizationProblem , alg, args... ; kwargs... )
217232 throw (OptimizerMissingError (alg))
218233end
234+
235+ function solve_up (prob:: SciMLBase.OptimizationProblem , sensealg, u0, p, args... ; originator = SciMLBase. ChainRulesOriginator (),
236+ kwargs... )
237+ alg = extract_opt_alg (args, kwargs, has_kwargs (prob) ? prob. kwargs : kwargs)
238+ _prob = get_concrete_problem (prob; u0 = u0, p = p, kwargs... )
239+ if length (args) < 1
240+ solve_call (_prob, alg, Base. tails (args)... , kwargs... )
241+ else
242+ solve_call (_prob, alg; kwargs... )
243+ end
244+ end
245+
246+ function solve_call (_prob, alg, args... ; merge_callbacks = true , kwargshandle = nothing ,
247+ kwargs... )
248+ kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle
249+ kwargshandle = has_kwargs (_prob) && haskey (_prob. kwargs, :kwargshandle ) ?
250+ _prob. kwargs[:kwargshandle ] : kwargshandle
251+
252+ if has_kwargs (_prob)
253+ kwargs = isempty (_prob. kwargs) ? kwargs : merge (values (_prob. kwargs), kwargs)
254+ end
255+
256+ checkkwargs (kwargshandle; kwargs... )
257+
258+ if SciMLBase. has_init (alg)
259+ solve! (init (_prob, alg, args... ; kwargs... ))
260+ else
261+ if _prob. u0 != = nothing && ! isconcretetype (eltype (_prob. u0))
262+ throw (SciMLBase. NonConcreteEltypeError (eltype (_prob. u0)))
263+ end
264+ _check_opt_alg (prob, alg; kwargs... )
265+ __solve (_prob, alg, args... ; kwargs... )
266+ end
267+ end
268+
269+ function get_concrete_problem (prob:: OptimizationProblem ; kwargs... )
270+ oldprob = prob
271+ prob = get_updated_symbolic_problem (get_root_indp (prob), prob; kwargs... )
272+ if prob != = oldprob
273+ kwargs = (;kwargs... , u0 = SII. state_values (prob), p = SII. parameter_values (prob))
274+ end
275+ p = get_concrete_p (prob, kwargs)
276+ u0 = get_concrete_u0 (prob, false , nothing , kwargs)
277+ u0 = promote_u0 (u0, p, nothing )
278+ remake (prob; u0 = u0, p = p)
279+
280+ end
281+
282+
283+ @inline function extract_opt_alg (solve_args, solve_kwargs, prob_kwargs)
284+ if isempty (solve_args) || isnothing (first (solve_args))
285+ if haskey (solve_kwargs, :alg )
286+ solve_kwargs[:alg ]
287+ elseif haskey (prob_kwargs, :alg )
288+ prob_kwargs[:alg ]
289+ else
290+ nothing
291+ end
292+ else
293+ first (solve_args)
294+ end
295+ end
296+
297+
298+ function _solve_forward (prob, sensealg, u0, p, originator, args... ; merge_callbacks = true ,
299+ kwargs... )
300+ alg = extract_opt_alg (args, kwargs, prob. kwargs)
301+ _prob = get_concrete_problem (prob; u0 = u0, p = p, kwargs... )
302+
303+ if has_kwargs (_prob)
304+ kwargs = isempty (_porb. kwargs) ? kwargs : merge (values (_prob. kwargs), kwargs)
305+ end
306+
307+ if length (args) > 1
308+ _concrete_solve_forward (_prob, alg, sensealg, u0, p, originator,
309+ Base. tail (args)... ; kwargs... )
310+ else
311+ _concrete_solve_forward (_prob, alg, sensealg, u0, p, originator; kwargs... )
312+ end
313+ end
314+
315+ function _solve_adjoint (_prob, sensealg, u0, p, originator, args... ; merge_callbacks = true ,
316+ kwargs... )
317+ alg = extract_alg (args, kwargs, prob. kwargs)
318+
319+ _prob = get_concrete_problem (prob; u0 = u0, p = p, kwargs... )
320+
321+ if has_kwargs (_prob)
322+ kwargs = isempty (_prob. kwargs) ? kwargs : merge (values (_prob. kwargs), kwargs)
323+ end
324+
325+ if length (args) > 1
326+ _concrete_solve_adjoint (_prob, alg, sensealg, u0, p, originator,
327+ Base. tail (args)... ; kwargs... )
328+ else
329+ _concrete_solve_adjoint (_prob, alg, sensealg, u0, p, originator; kwargs... )
330+ end
331+ end
0 commit comments