Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 6bd3b78

Browse files
Update function.jl
Changed the position of the MOO function checks in instantiate_function.
1 parent 8ed3d02 commit 6bd3b78

File tree

1 file changed

+54
-49
lines changed

1 file changed

+54
-49
lines changed

src/function.jl

Lines changed: 54 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,17 @@ function that is not defined, an error is thrown.
4343
For more information on the use of automatic differentiation, see the
4444
documentation of the `AbstractADType` types.
4545
"""
46-
function instantiate_function(f, x, ::SciMLBase.NoAD,
46+
47+
48+
function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLBase.NoAD,
4749
p, num_cons = 0)
48-
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
49-
hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
50+
jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, p, args...)
51+
hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, p, args...) for h in f.hess]
5052
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
5153
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
5254
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
55+
cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p)
56+
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
5357
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
5458
hess_prototype = f.hess_prototype === nothing ? nothing :
5559
convert.(eltype(x), f.hess_prototype)
@@ -61,9 +65,9 @@ function instantiate_function(f, x, ::SciMLBase.NoAD,
6165
expr = symbolify(f.expr)
6266
cons_expr = symbolify.(f.cons_expr)
6367

64-
return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
68+
return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
6569
hv = hv,
66-
cons = cons, cons_j = cons_j, cons_h = cons_h,
70+
cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h,
6771
hess_prototype = hess_prototype,
6872
cons_jac_prototype = cons_jac_prototype,
6973
cons_hess_prototype = cons_hess_prototype,
@@ -72,13 +76,28 @@ function instantiate_function(f, x, ::SciMLBase.NoAD,
7276
observed = f.observed)
7377
end
7478

75-
function instantiate_function(f, cache::ReInitCache, ::SciMLBase.NoAD,
79+
function instantiate_function(f::MultiObjectiveOptimizationFunction, x, adtype::ADTypes.AbstractADType,
80+
p, num_cons = 0)
81+
adtypestr = string(adtype)
82+
_strtind = findfirst('.', adtypestr)
83+
strtind = isnothing(_strtind) ? 5 : _strtind + 5
84+
open_nrmlbrkt_ind = findfirst('(', adtypestr)
85+
open_squigllybrkt_ind = findfirst('{', adtypestr)
86+
open_brkt_ind = isnothing(open_squigllybrkt_ind) ? open_nrmlbrkt_ind :
87+
min(open_nrmlbrkt_ind, open_squigllybrkt_ind)
88+
adpkg = adtypestr[strtind:(open_brkt_ind - 1)]
89+
throw(ArgumentError("The passed automatic differentiation backend choice is not available. Please load the corresponding AD package $adpkg."))
90+
end
91+
92+
function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReInitCache, ::SciMLBase.NoAD,
7693
num_cons = 0)
77-
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...)
78-
hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...)
94+
jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, cache.p, args...)
95+
hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, cache.p, args...) for h in f.hess]
7996
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...)
8097
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p)
8198
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p)
99+
cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, cache.p)
100+
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, cache.p)
82101
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p)
83102
hess_prototype = f.hess_prototype === nothing ? nothing :
84103
convert.(eltype(cache.u0), f.hess_prototype)
@@ -90,9 +109,9 @@ function instantiate_function(f, cache::ReInitCache, ::SciMLBase.NoAD,
90109
expr = symbolify(f.expr)
91110
cons_expr = symbolify.(f.cons_expr)
92111

93-
return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
112+
return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
94113
hv = hv,
95-
cons = cons, cons_j = cons_j, cons_h = cons_h,
114+
cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h,
96115
hess_prototype = hess_prototype,
97116
cons_jac_prototype = cons_jac_prototype,
98117
cons_hess_prototype = cons_hess_prototype,
@@ -101,28 +120,14 @@ function instantiate_function(f, cache::ReInitCache, ::SciMLBase.NoAD,
101120
observed = f.observed)
102121
end
103122

104-
function instantiate_function(f, x, adtype::ADTypes.AbstractADType,
105-
p, num_cons = 0)
106-
adtypestr = string(adtype)
107-
_strtind = findfirst('.', adtypestr)
108-
strtind = isnothing(_strtind) ? 5 : _strtind + 5
109-
open_nrmlbrkt_ind = findfirst('(', adtypestr)
110-
open_squigllybrkt_ind = findfirst('{', adtypestr)
111-
open_brkt_ind = isnothing(open_squigllybrkt_ind) ? open_nrmlbrkt_ind :
112-
min(open_nrmlbrkt_ind, open_squigllybrkt_ind)
113-
adpkg = adtypestr[strtind:(open_brkt_ind - 1)]
114-
throw(ArgumentError("The passed automatic differentiation backend choice is not available. Please load the corresponding AD package $adpkg."))
115-
end
116123

117-
function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLBase.NoAD,
124+
function instantiate_function(f, x, ::SciMLBase.NoAD,
118125
p, num_cons = 0)
119-
jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, p, args...)
120-
hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, p, args...) for h in f.hess]
126+
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
127+
hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
121128
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
122129
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
123130
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
124-
cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p)
125-
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
126131
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
127132
hess_prototype = f.hess_prototype === nothing ? nothing :
128133
convert.(eltype(x), f.hess_prototype)
@@ -134,9 +139,9 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLB
134139
expr = symbolify(f.expr)
135140
cons_expr = symbolify.(f.cons_expr)
136141

137-
return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
142+
return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
138143
hv = hv,
139-
cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h,
144+
cons = cons, cons_j = cons_j, cons_h = cons_h,
140145
hess_prototype = hess_prototype,
141146
cons_jac_prototype = cons_jac_prototype,
142147
cons_hess_prototype = cons_hess_prototype,
@@ -145,28 +150,13 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLB
145150
observed = f.observed)
146151
end
147152

148-
function instantiate_function(f::MultiObjectiveOptimizationFunction, x, adtype::ADTypes.AbstractADType,
149-
p, num_cons = 0)
150-
adtypestr = string(adtype)
151-
_strtind = findfirst('.', adtypestr)
152-
strtind = isnothing(_strtind) ? 5 : _strtind + 5
153-
open_nrmlbrkt_ind = findfirst('(', adtypestr)
154-
open_squigllybrkt_ind = findfirst('{', adtypestr)
155-
open_brkt_ind = isnothing(open_squigllybrkt_ind) ? open_nrmlbrkt_ind :
156-
min(open_nrmlbrkt_ind, open_squigllybrkt_ind)
157-
adpkg = adtypestr[strtind:(open_brkt_ind - 1)]
158-
throw(ArgumentError("The passed automatic differentiation backend choice is not available. Please load the corresponding AD package $adpkg."))
159-
end
160-
161-
function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReInitCache, ::SciMLBase.NoAD,
153+
function instantiate_function(f, cache::ReInitCache, ::SciMLBase.NoAD,
162154
num_cons = 0)
163-
jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, cache.p, args...)
164-
hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, cache.p, args...) for h in f.hess]
155+
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...)
156+
hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...)
165157
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...)
166158
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p)
167159
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p)
168-
cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, cache.p)
169-
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, cache.p)
170160
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p)
171161
hess_prototype = f.hess_prototype === nothing ? nothing :
172162
convert.(eltype(cache.u0), f.hess_prototype)
@@ -178,13 +168,28 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReIn
178168
expr = symbolify(f.expr)
179169
cons_expr = symbolify.(f.cons_expr)
180170

181-
return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
171+
return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
182172
hv = hv,
183-
cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h,
173+
cons = cons, cons_j = cons_j, cons_h = cons_h,
184174
hess_prototype = hess_prototype,
185175
cons_jac_prototype = cons_jac_prototype,
186176
cons_hess_prototype = cons_hess_prototype,
187177
expr = expr, cons_expr = cons_expr,
188178
sys = f.sys,
189179
observed = f.observed)
190180
end
181+
182+
function instantiate_function(f, x, adtype::ADTypes.AbstractADType,
183+
p, num_cons = 0)
184+
adtypestr = string(adtype)
185+
_strtind = findfirst('.', adtypestr)
186+
strtind = isnothing(_strtind) ? 5 : _strtind + 5
187+
open_nrmlbrkt_ind = findfirst('(', adtypestr)
188+
open_squigllybrkt_ind = findfirst('{', adtypestr)
189+
open_brkt_ind = isnothing(open_squigllybrkt_ind) ? open_nrmlbrkt_ind :
190+
min(open_nrmlbrkt_ind, open_squigllybrkt_ind)
191+
adpkg = adtypestr[strtind:(open_brkt_ind - 1)]
192+
throw(ArgumentError("The passed automatic differentiation backend choice is not available. Please load the corresponding AD package $adpkg."))
193+
end
194+
195+

0 commit comments

Comments
 (0)