Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit f040740

Browse files
committed
Make mutable pre-allocated Jacobian default oop autodiff method whenever J is mutable.
1 parent bb80238 commit f040740

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,16 @@ end
6868
jac_prototype = nothing,
6969
chunksize = nothing,
7070
dx = sparsity === nothing && jac_prototype === nothing ? nothing : copy(x)) #if dx is nothing, we will estimate dx at the cost of a function call
71+
@show typeof(x)
72+
7173
if sparsity === nothing && jac_prototype === nothing || !ArrayInterface.ismutable(x)
7274
cfg = chunksize === nothing ? ForwardDiff.JacobianConfig(f, x) : ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize)))
7375
return ForwardDiff.jacobian(f, x, cfg)
7476
end
7577
if dx isa Nothing
7678
dx = f(x)
7779
end
80+
@show "Line 80"
7881
forwarddiff_color_jacobian(f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity),jac_prototype)
7982
end
8083

@@ -84,11 +87,12 @@ end
8487
sparsity = nothing,
8588
jac_prototype = nothing,
8689
chunksize = nothing,
87-
dx = similar(x, size(J, 1))) #if dx is nothing, we will estimate dx at the cost of a function call
90+
dx = similar(x, size(J, 1))) #dx kwarg can be used to avoid re-allocating dx every time
8891
if sparsity === nothing && jac_prototype === nothing || !ArrayInterface.ismutable(x)
8992
cfg = chunksize === nothing ? ForwardDiff.JacobianConfig(f, x) : ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize)))
9093
return ForwardDiff.jacobian(f, x, cfg)
9194
end
95+
@show "Line 95"
9296
forwarddiff_color_jacobian(J,f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity),jac_prototype)
9397
end
9498

@@ -99,10 +103,18 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
99103

100104
J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec(dx) .* vecx' : zeros(eltype(x),size(sparsity))) : zero(jac_prototype)
101105

102-
forwarddiff_color_jacobian(J, f, x, jac_cache, jac_prototype)
106+
@show typeof(J)
107+
if ArrayInterface.ismutable(J) # Whenever J is mutable, we mutate it to avoid allocations
108+
@show "Line 108"
109+
forwarddiff_color_jacobian(J, f, x, jac_cache, jac_prototype)
110+
else
111+
@show "Line 111"
112+
forwarddiff_color_jacobian_immutable(J, f, x, jac_cache, jac_prototype)
113+
end
103114
end
104115

105-
function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
116+
# When J is mutable, this version of forwarddiff_color_jacobian will mutate J to avoid allocations
117+
function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache)
106118
t = jac_cache.t
107119
dx = jac_cache.dx
108120
p = jac_cache.p
@@ -141,7 +153,11 @@ function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArra
141153
cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows]
142154
Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols]
143155
end
144-
J = J + Ji
156+
if j == 1 && i == 1
157+
J .= Ji # overwrite pre-allocated matrix
158+
else
159+
J .+= Ji
160+
end
145161
color_i += 1
146162
(color_i > maxcolor) && return J
147163
end
@@ -150,14 +166,19 @@ function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArra
150166
col_index = (i-1)*chunksize + j
151167
(col_index > ncols) && return J
152168
Ji = mapreduce(i -> i==col_index ? partials.(vec(fx), j) : adapt(parameterless_type(J),zeros(eltype(J),nrows)), hcat, 1:ncols)
153-
J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
169+
if j == 1 && i == 1
170+
J .= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) # overwrite pre-allocated matrix
171+
else
172+
J .+= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
173+
end
154174
end
155175
end
156176
end
157177
J
158178
end
159179

160-
function forwarddiff_color_jacobian(J::SparseMatrixCSC{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
180+
# When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J
181+
function forwarddiff_color_jacobian_immutable(J::AbstractArray{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache)
161182
t = jac_cache.t
162183
dx = jac_cache.dx
163184
p = jac_cache.p
@@ -187,13 +208,16 @@ function forwarddiff_color_jacobian(J::SparseMatrixCSC{<:Number},f,x::AbstractAr
187208
pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i]
188209
rows_index_c = rows_index[pick_inds]
189210
cols_index_c = cols_index[pick_inds]
190-
Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
191-
# J = J + Ji
192-
if j == 1 && i == 1
193-
J .= Ji # overwrite pre-allocated matrix
211+
if J isa SparseMatrixCSC
212+
Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
194213
else
195-
J .+= Ji
214+
len_rows = length(pick_inds)
215+
unused_rows = setdiff(1:nrows,rows_index_c)
216+
perm_rows = sortperm(vcat(rows_index_c,unused_rows))
217+
cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows]
218+
Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols]
196219
end
220+
J = J + Ji
197221
color_i += 1
198222
(color_i > maxcolor) && return J
199223
end
@@ -202,12 +226,7 @@ function forwarddiff_color_jacobian(J::SparseMatrixCSC{<:Number},f,x::AbstractAr
202226
col_index = (i-1)*chunksize + j
203227
(col_index > ncols) && return J
204228
Ji = mapreduce(i -> i==col_index ? partials.(vec(fx), j) : adapt(parameterless_type(J),zeros(eltype(J),nrows)), hcat, 1:ncols)
205-
# J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
206-
if j == 1 && i == 1
207-
J .= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) # overwrite pre-allocated matrix
208-
else
209-
J .+= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
210-
end
229+
J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
211230
end
212231
end
213232
end

0 commit comments

Comments
 (0)