@@ -3,10 +3,112 @@ module OptimizationMTKExt
3
3
import OptimizationBase, OptimizationBase. ArrayInterface
4
4
import OptimizationBase. SciMLBase
5
5
import OptimizationBase. SciMLBase: OptimizationFunction
6
- import OptimizationBase. ADTypes: AutoModelingToolkit
6
+ import OptimizationBase. ADTypes: AutoModelingToolkit, AutoSymbolics, AutoSparse
7
7
isdefined (Base, :get_extension ) ? (using ModelingToolkit) : (using .. ModelingToolkit)
8
8
9
- function OptimizationBase. instantiate_function (f, x, adtype:: AutoModelingToolkit , p,
9
+ function OptimizationBase. ADTypes. AutoModelingToolkit (sparse = false , cons_sparse = false )
10
+ if sparse || cons_sparse
11
+ return AutoSparse (AutoSymbolics ())
12
+ else
13
+ return AutoSymbolics ()
14
+ end
15
+ end
16
+
17
+ function OptimizationBase. instantiate_function (
18
+ f, x, adtype:: AutoSparse{<:AutoSymbolics, S, C} , p,
19
+ num_cons = 0 ) where {S, C}
20
+ p = isnothing (p) ? SciMLBase. NullParameters () : p
21
+
22
+ sys = complete (ModelingToolkit. modelingtoolkitize (OptimizationProblem (f, x, p;
23
+ lcons = fill (0.0 ,
24
+ num_cons),
25
+ ucons = fill (0.0 ,
26
+ num_cons))))
27
+ # sys = ModelingToolkit.structural_simplify(sys)
28
+ f = OptimizationProblem (sys, x, p, grad = true , hess = true ,
29
+ sparse = true , cons_j = true , cons_h = true ,
30
+ cons_sparse = true ). f
31
+
32
+ grad = (G, θ, args... ) -> f. grad (G, θ, p, args... )
33
+
34
+ hess = (H, θ, args... ) -> f. hess (H, θ, p, args... )
35
+
36
+ hv = function (H, θ, v, args... )
37
+ res = adtype. obj_sparse ? (eltype (θ)). (f. hess_prototype) :
38
+ ArrayInterface. zeromatrix (θ)
39
+ hess (res, θ, args... )
40
+ H .= res * v
41
+ end
42
+
43
+ if ! isnothing (f. cons)
44
+ cons = (res, θ) -> f. cons (res, θ, p)
45
+ cons_j = (J, θ) -> f. cons_j (J, θ, p)
46
+ cons_h = (res, θ) -> f. cons_h (res, θ, p)
47
+ else
48
+ cons = nothing
49
+ cons_j = nothing
50
+ cons_h = nothing
51
+ end
52
+
53
+ return OptimizationFunction {true} (f. f, adtype; grad = grad, hess = hess, hv = hv,
54
+ cons = cons, cons_j = cons_j, cons_h = cons_h,
55
+ hess_prototype = f. hess_prototype,
56
+ cons_jac_prototype = f. cons_jac_prototype,
57
+ cons_hess_prototype = f. cons_hess_prototype,
58
+ expr = OptimizationBase. symbolify (f. expr),
59
+ cons_expr = OptimizationBase. symbolify .(f. cons_expr),
60
+ sys = sys,
61
+ observed = f. observed)
62
+ end
63
+
64
+ function OptimizationBase. instantiate_function (f, cache:: OptimizationBase.ReInitCache ,
65
+ adtype:: AutoSparse{<:AutoSymbolics, S, C} , num_cons = 0 ) where {S, C}
66
+ p = isnothing (cache. p) ? SciMLBase. NullParameters () : cache. p
67
+
68
+ sys = complete (ModelingToolkit. modelingtoolkitize (OptimizationProblem (f, cache. u0,
69
+ cache. p;
70
+ lcons = fill (0.0 ,
71
+ num_cons),
72
+ ucons = fill (0.0 ,
73
+ num_cons))))
74
+ # sys = ModelingToolkit.structural_simplify(sys)
75
+ f = OptimizationProblem (sys, cache. u0, cache. p, grad = true , hess = true ,
76
+ sparse = true , cons_j = true , cons_h = true ,
77
+ cons_sparse = true ). f
78
+
79
+ grad = (G, θ, args... ) -> f. grad (G, θ, cache. p, args... )
80
+
81
+ hess = (H, θ, args... ) -> f. hess (H, θ, cache. p, args... )
82
+
83
+ hv = function (H, θ, v, args... )
84
+ res = adtype. obj_sparse ? (eltype (θ)). (f. hess_prototype) :
85
+ ArrayInterface. zeromatrix (θ)
86
+ hess (res, θ, args... )
87
+ H .= res * v
88
+ end
89
+
90
+ if ! isnothing (f. cons)
91
+ cons = (res, θ) -> f. cons (res, θ, cache. p)
92
+ cons_j = (J, θ) -> f. cons_j (J, θ, cache. p)
93
+ cons_h = (res, θ) -> f. cons_h (res, θ, cache. p)
94
+ else
95
+ cons = nothing
96
+ cons_j = nothing
97
+ cons_h = nothing
98
+ end
99
+
100
+ return OptimizationFunction {true} (f. f, adtype; grad = grad, hess = hess, hv = hv,
101
+ cons = cons, cons_j = cons_j, cons_h = cons_h,
102
+ hess_prototype = f. hess_prototype,
103
+ cons_jac_prototype = f. cons_jac_prototype,
104
+ cons_hess_prototype = f. cons_hess_prototype,
105
+ expr = OptimizationBase. symbolify (f. expr),
106
+ cons_expr = OptimizationBase. symbolify .(f. cons_expr),
107
+ sys = sys,
108
+ observed = f. observed)
109
+ end
110
+
111
+ function OptimizationBase. instantiate_function (f, x, adtype:: AutoSymbolics , p,
10
112
num_cons = 0 )
11
113
p = isnothing (p) ? SciMLBase. NullParameters () : p
12
114
@@ -17,8 +119,8 @@ function OptimizationBase.instantiate_function(f, x, adtype::AutoModelingToolkit
17
119
num_cons))))
18
120
# sys = ModelingToolkit.structural_simplify(sys)
19
121
f = OptimizationProblem (sys, x, p, grad = true , hess = true ,
20
- sparse = adtype . obj_sparse , cons_j = true , cons_h = true ,
21
- cons_sparse = adtype . cons_sparse ). f
122
+ sparse = false , cons_j = true , cons_h = true ,
123
+ cons_sparse = false ). f
22
124
23
125
grad = (G, θ, args... ) -> f. grad (G, θ, p, args... )
24
126
@@ -53,7 +155,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::AutoModelingToolkit
53
155
end
54
156
55
157
function OptimizationBase. instantiate_function (f, cache:: OptimizationBase.ReInitCache ,
56
- adtype:: AutoModelingToolkit , num_cons = 0 )
158
+ adtype:: AutoSymbolics , num_cons = 0 )
57
159
p = isnothing (cache. p) ? SciMLBase. NullParameters () : cache. p
58
160
59
161
sys = complete (ModelingToolkit. modelingtoolkitize (OptimizationProblem (f, cache. u0,
@@ -64,8 +166,8 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit
64
166
num_cons))))
65
167
# sys = ModelingToolkit.structural_simplify(sys)
66
168
f = OptimizationProblem (sys, cache. u0, cache. p, grad = true , hess = true ,
67
- sparse = adtype . obj_sparse , cons_j = true , cons_h = true ,
68
- cons_sparse = adtype . cons_sparse ). f
169
+ sparse = false , cons_j = true , cons_h = true ,
170
+ cons_sparse = false ). f
69
171
70
172
grad = (G, θ, args... ) -> f. grad (G, θ, cache. p, args... )
71
173
0 commit comments