@@ -5,8 +5,17 @@ function SciMLBase.solve(
55    sol, partials =  __nlsolve_ad (prob, alg, args... ; kwargs... )
66    dual_soln =  __nlsolve_dual_soln (sol. u, partials, prob. p)
77    return  SciMLBase. build_solution (
8-         prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
9-         sol. original)
8+         prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
9+ end 
10+ 
11+ function  SciMLBase. solve (
12+         prob:: NonlinearLeastSquaresProblem {<: AbstractArray ,
13+             iip, <: Union{<:AbstractArray{<:Dual{T, V, P}}} },
14+         alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where  {T, V, P, iip}
15+     sol, partials =  __nlsolve_ad (prob, alg, args... ; kwargs... )
16+     dual_soln =  __nlsolve_dual_soln (sol. u, partials, prob. p)
17+     return  SciMLBase. build_solution (
18+         prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1019end 
1120
1221for  algType in  (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -24,7 +33,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
2433    end 
2534end 
2635
27- function  __nlsolve_ad (prob, alg, args... ; kwargs... )
36+ function  __nlsolve_ad (
37+         prob:: Union{IntervalNonlinearProblem, NonlinearProblem} , alg, args... ; kwargs... )
2838    p =  value (prob. p)
2939    if  prob isa  IntervalNonlinearProblem
3040        tspan =  value .(prob. tspan)
@@ -55,6 +65,96 @@ function __nlsolve_ad(prob, alg, args...; kwargs...)
5565    return  sol, partials
5666end 
5767
68+ function  __nlsolve_ad (prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs... )
69+     p =  value (prob. p)
70+     u0 =  value (prob. u0)
71+     newprob =  NonlinearLeastSquaresProblem (prob. f, u0, p; prob. kwargs... )
72+ 
73+     sol =  solve (newprob, alg, args... ; kwargs... )
74+ 
75+     uu =  sol. u
76+ 
77+     #  First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
78+     #  nested autodiff as the last resort
79+     if  SciMLBase. has_vjp (prob. f)
80+         if  isinplace (prob)
81+             _F =  @closure  (du, u, p) ->  begin 
82+                 resid =  similar (du, length (sol. resid))
83+                 prob. f (resid, u, p)
84+                 prob. f. vjp (du, resid, u, p)
85+                 du .*=  2 
86+                 return  nothing 
87+             end 
88+         else 
89+             _F =  @closure  (u, p) ->  begin 
90+                 resid =  prob. f (u, p)
91+                 return  reshape (2  .*  prob. f. vjp (resid, u, p), size (u))
92+             end 
93+         end 
94+     elseif  SciMLBase. has_jac (prob. f)
95+         if  isinplace (prob)
96+             _F =  @closure  (du, u, p) ->  begin 
97+                 J =  similar (du, length (sol. resid), length (u))
98+                 prob. f. jac (J, u, p)
99+                 resid =  similar (du, length (sol. resid))
100+                 prob. f (resid, u, p)
101+                 mul! (reshape (du, 1 , :), vec (resid)' , J, 2 , false )
102+                 return  nothing 
103+             end 
104+         else 
105+             _F =  @closure  (u, p) ->  begin 
106+                 return  reshape (2  .*  vec (prob. f (u, p))'  *  prob. f. jac (u, p), size (u))
107+             end 
108+         end 
109+     else 
110+         if  isinplace (prob)
111+             _F =  @closure  (du, u, p) ->  begin 
112+                 resid =  similar (du, length (sol. resid))
113+                 res =  DiffResults. DiffResult (
114+                     resid, similar (du, length (sol. resid), length (u)))
115+                 _f =  @closure  (du, u) ->  prob. f (du, u, p)
116+                 ForwardDiff. jacobian! (res, _f, resid, u)
117+                 mul! (reshape (du, 1 , :), vec (DiffResults. value (res))' ,
118+                     DiffResults. jacobian (res), 2 , false )
119+                 return  nothing 
120+             end 
121+         else 
122+             #  For small problems, nesting ForwardDiff is actually quite fast
123+             if  __is_extension_loaded (Val (:Zygote )) &&  (length (uu) +  length (sol. resid) ≥  50 )
124+                 _F =  @closure  (u, p) ->  __zygote_compute_nlls_vjp (prob. f, u, p)
125+             else 
126+                 _F =  @closure  (u, p) ->  begin 
127+                     T =  promote_type (eltype (u), eltype (p))
128+                     res =  DiffResults. DiffResult (
129+                         similar (u, T, size (sol. resid)), similar (
130+                             u, T, length (sol. resid), length (u)))
131+                     ForwardDiff. jacobian! (res, Base. Fix2 (prob. f, p), u)
132+                     return  reshape (
133+                         2  .*  vec (DiffResults. value (res))'  *  DiffResults. jacobian (res),
134+                         size (u))
135+                 end 
136+             end 
137+         end 
138+     end 
139+ 
140+     f_p =  __nlsolve_∂f_∂p (prob, _F, uu, p)
141+     f_x =  __nlsolve_∂f_∂u (prob, _F, uu, p)
142+ 
143+     z_arr =  - f_x \  f_p
144+ 
145+     pp =  prob. p
146+     sumfun =  ((z, p),) ->  map (zᵢ ->  zᵢ *  ForwardDiff. partials (p), z)
147+     if  uu isa  Number
148+         partials =  sum (sumfun, zip (z_arr, pp))
149+     elseif  p isa  Number
150+         partials =  sumfun ((z_arr, pp))
151+     else 
152+         partials =  sum (sumfun, zip (eachcol (z_arr), pp))
153+     end 
154+ 
155+     return  sol, partials
156+ end 
157+ 
58158@inline  function  __nlsolve_∂f_∂p (prob, f:: F , u, p) where  {F}
59159    if  isinplace (prob)
60160        __f =  p ->  begin 
0 commit comments