68
68
jac_prototype = nothing ,
69
69
chunksize = nothing ,
70
70
dx = sparsity === nothing && jac_prototype === nothing ? nothing : copy (x)) # if dx is nothing, we will estimate dx at the cost of a function call
71
+ @show typeof (x)
72
+
71
73
if sparsity === nothing && jac_prototype === nothing || ! ArrayInterface. ismutable (x)
72
74
cfg = chunksize === nothing ? ForwardDiff. JacobianConfig (f, x) : ForwardDiff. JacobianConfig (f, x, ForwardDiff. Chunk (getsize (chunksize)))
73
75
return ForwardDiff. jacobian (f, x, cfg)
74
76
end
75
77
if dx isa Nothing
76
78
dx = f (x)
77
79
end
80
+ @show " Line 80"
78
81
forwarddiff_color_jacobian (f,x,ForwardColorJacCache (f,x,chunksize,dx= dx,colorvec= colorvec,sparsity= sparsity),jac_prototype)
79
82
end
80
83
84
87
sparsity = nothing ,
85
88
jac_prototype = nothing ,
86
89
chunksize = nothing ,
87
- dx = similar (x, size (J, 1 ))) # if dx is nothing, we will estimate dx at the cost of a function call
90
+ dx = similar (x, size (J, 1 ))) # dx kwarg can be used to avoid re-allocating dx every time
88
91
if sparsity === nothing && jac_prototype === nothing || ! ArrayInterface. ismutable (x)
89
92
cfg = chunksize === nothing ? ForwardDiff. JacobianConfig (f, x) : ForwardDiff. JacobianConfig (f, x, ForwardDiff. Chunk (getsize (chunksize)))
90
93
return ForwardDiff. jacobian (f, x, cfg)
91
94
end
95
+ @show " Line 95"
92
96
forwarddiff_color_jacobian (J,f,x,ForwardColorJacCache (f,x,chunksize,dx= dx,colorvec= colorvec,sparsity= sparsity),jac_prototype)
93
97
end
94
98
@@ -99,10 +103,18 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
99
103
100
104
J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec (dx) .* vecx' : zeros (eltype (x),size (sparsity))) : zero (jac_prototype)
101
105
102
- forwarddiff_color_jacobian (J, f, x, jac_cache, jac_prototype)
106
+ @show typeof (J)
107
+ if ArrayInterface. ismutable (J) # Whenever J is mutable, we mutate it to avoid allocations
108
+ @show " Line 108"
109
+ forwarddiff_color_jacobian (J, f, x, jac_cache, jac_prototype)
110
+ else
111
+ @show " Line 111"
112
+ forwarddiff_color_jacobian_immutable (J, f, x, jac_cache, jac_prototype)
113
+ end
103
114
end
104
115
105
- function forwarddiff_color_jacobian (J:: AbstractArray{<:Number} ,f,x:: AbstractArray{<:Number} ,jac_cache:: ForwardColorJacCache ,jac_prototype= nothing )
116
+ # When J is mutable, this version of forwarddiff_color_jacobian will mutate J to avoid allocations
117
+ function forwarddiff_color_jacobian (J:: AbstractMatrix{<:Number} ,f,x:: AbstractArray{<:Number} ,jac_cache:: ForwardColorJacCache )
106
118
t = jac_cache. t
107
119
dx = jac_cache. dx
108
120
p = jac_cache. p
@@ -141,7 +153,11 @@ function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArra
141
153
cols_index_c = vcat (cols_index_c,zeros (Int,nrows- len_rows))[perm_rows]
142
154
Ji = [j== cols_index_c[i] ? dx[i] : false for i in 1 : nrows, j in 1 : ncols]
143
155
end
144
- J = J + Ji
156
+ if j == 1 && i == 1
157
+ J .= Ji # overwrite pre-allocated matrix
158
+ else
159
+ J .+ = Ji
160
+ end
145
161
color_i += 1
146
162
(color_i > maxcolor) && return J
147
163
end
@@ -150,14 +166,19 @@ function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArra
150
166
col_index = (i- 1 )* chunksize + j
151
167
(col_index > ncols) && return J
152
168
Ji = mapreduce (i -> i== col_index ? partials .(vec (fx), j) : adapt (parameterless_type (J),zeros (eltype (J),nrows)), hcat, 1 : ncols)
153
- J = J + (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
169
+ if j == 1 && i == 1
170
+ J .= (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # overwrite pre-allocated matrix
171
+ else
172
+ J .+ = (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
173
+ end
154
174
end
155
175
end
156
176
end
157
177
J
158
178
end
159
179
160
- function forwarddiff_color_jacobian (J:: SparseMatrixCSC{<:Number} ,f,x:: AbstractArray{<:Number} ,jac_cache:: ForwardColorJacCache ,jac_prototype= nothing )
180
+ # When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J
181
+ function forwarddiff_color_jacobian_immutable (J:: AbstractArray{<:Number} ,f,x:: AbstractArray{<:Number} ,jac_cache:: ForwardColorJacCache )
161
182
t = jac_cache. t
162
183
dx = jac_cache. dx
163
184
p = jac_cache. p
@@ -187,13 +208,16 @@ function forwarddiff_color_jacobian(J::SparseMatrixCSC{<:Number},f,x::AbstractAr
187
208
pick_inds = [i for i in 1 : length (rows_index) if colorvec[cols_index[i]] == color_i]
188
209
rows_index_c = rows_index[pick_inds]
189
210
cols_index_c = cols_index[pick_inds]
190
- Ji = sparse (rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
191
- # J = J + Ji
192
- if j == 1 && i == 1
193
- J .= Ji # overwrite pre-allocated matrix
211
+ if J isa SparseMatrixCSC
212
+ Ji = sparse (rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
194
213
else
195
- J .+ = Ji
214
+ len_rows = length (pick_inds)
215
+ unused_rows = setdiff (1 : nrows,rows_index_c)
216
+ perm_rows = sortperm (vcat (rows_index_c,unused_rows))
217
+ cols_index_c = vcat (cols_index_c,zeros (Int,nrows- len_rows))[perm_rows]
218
+ Ji = [j== cols_index_c[i] ? dx[i] : false for i in 1 : nrows, j in 1 : ncols]
196
219
end
220
+ J = J + Ji
197
221
color_i += 1
198
222
(color_i > maxcolor) && return J
199
223
end
@@ -202,12 +226,7 @@ function forwarddiff_color_jacobian(J::SparseMatrixCSC{<:Number},f,x::AbstractAr
202
226
col_index = (i- 1 )* chunksize + j
203
227
(col_index > ncols) && return J
204
228
Ji = mapreduce (i -> i== col_index ? partials .(vec (fx), j) : adapt (parameterless_type (J),zeros (eltype (J),nrows)), hcat, 1 : ncols)
205
- # J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
206
- if j == 1 && i == 1
207
- J .= (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # overwrite pre-allocated matrix
208
- else
209
- J .+ = (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
210
- end
229
+ J = J + (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
211
230
end
212
231
end
213
232
end
0 commit comments