@@ -5,8 +5,17 @@ function SciMLBase.solve(
55 sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
66 dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
77 return SciMLBase. build_solution (
8- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
9- sol. original)
8+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
9+ end
10+
11+ function SciMLBase. solve (
12+ prob:: NonlinearLeastSquaresProblem {<: AbstractArray ,
13+ iip, <: Union{<:AbstractArray{<:Dual{T, V, P}}} },
14+ alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P, iip}
15+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
16+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
17+ return SciMLBase. build_solution (
18+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1019end
1120
1221for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -56,6 +65,79 @@ function __nlsolve_ad(
5665 return sol, partials
5766end
5867
68+ function __nlsolve_ad (prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs... )
69+ p = value (prob. p)
70+ u0 = value (prob. u0)
71+ newprob = NonlinearLeastSquaresProblem (prob. f, u0, p; prob. kwargs... )
72+
73+ sol = solve (newprob, alg, args... ; kwargs... )
74+
75+ uu = sol. u
76+
77+ if ! SciMLBase. has_jac (prob. f)
78+ if isinplace (prob)
79+ _F = @closure (du, u, p) -> begin
80+ resid = similar (du, length (sol. resid))
81+ res = DiffResults. DiffResult (
82+ resid, similar (du, length (sol. resid), length (u)))
83+ _f = @closure (du, u) -> prob. f (du, u, p)
84+ ForwardDiff. jacobian! (res, _f, resid, u)
85+ mul! (reshape (du, 1 , :), vec (DiffResults. value (res))' ,
86+ DiffResults. jacobian (res), 2 , false )
87+ return nothing
88+ end
89+ else
90+ # For small problems, nesting ForwardDiff is actually quite fast
91+ if __is_extension_loaded (Val (:Zygote )) && (length (uu) + length (sol. resid) ≥ 50 )
92+ _F = @closure (u, p) -> __zygote_compute_nlls_vjp (prob. f, u, p)
93+ else
94+ _F = @closure (u, p) -> begin
95+ T = promote_type (eltype (u), eltype (p))
96+ res = DiffResults. DiffResult (
97+ similar (u, T, size (sol. resid)), similar (
98+ u, T, length (sol. resid), length (u)))
99+ ForwardDiff. jacobian! (res, Base. Fix2 (prob. f, p), u)
100+ return reshape (
101+ 2 .* vec (DiffResults. value (res))' * DiffResults. jacobian (res),
102+ size (u))
103+ end
104+ end
105+ end
106+ else
107+ if isinplace (prob)
108+ _F = @closure (du, u, p) -> begin
109+ J = similar (du, length (sol. resid), length (u))
110+ prob. jac (J, u, p)
111+ resid = similar (du, length (sol. resid))
112+ prob. f (resid, u, p)
113+ mul! (reshape (du, 1 , :), vec (resid)' , J, 2 , false )
114+ return nothing
115+ end
116+ else
117+ _F = @closure (u, p) -> begin
118+ return reshape (2 .* vec (prob. f (u, p))' * prob. jac (u, p), size (u))
119+ end
120+ end
121+ end
122+
123+ f_p = __nlsolve_∂f_∂p (prob, _F, uu, p)
124+ f_x = __nlsolve_∂f_∂u (prob, _F, uu, p)
125+
126+ z_arr = - f_x \ f_p
127+
128+ pp = prob. p
129+ sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
130+ if uu isa Number
131+ partials = sum (sumfun, zip (z_arr, pp))
132+ elseif p isa Number
133+ partials = sumfun ((z_arr, pp))
134+ else
135+ partials = sum (sumfun, zip (eachcol (z_arr), pp))
136+ end
137+
138+ return sol, partials
139+ end
140+
59141@inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
60142 if isinplace (prob)
61143 __f = p -> begin
0 commit comments