Skip to content

Commit 57a23f5

Browse files
committed
added update! method for sparsematrixlnk
fixed allocations in Base.+
1 parent ad12ebe commit 57a23f5

File tree

1 file changed

+76
-48
lines changed

1 file changed

+76
-48
lines changed

src/sparsematrixlnk.jl

Lines changed: 76 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -124,32 +124,62 @@ end
124124

125125

126126

127-
128-
129-
"""
130-
$(SIGNATURES)
131-
132-
Return value stored for entry or zero if not found
133-
"""
134-
function Base.getindex(lnk::SparseMatrixLNK{Tv,Ti},i::Integer, j::Integer) where {Tv,Ti<:Integer}
135-
127+
function findindex(lnk::SparseMatrixLNK{Tv,Ti},i::Integer, j::Integer) where {Tv,Ti<:Integer}
136128
if !((1 <= i <= lnk.m) & (1 <= j <= lnk.n))
137129
throw(BoundsError(lnk, (i,j)))
138130
end
139131

140132
k=j
133+
k0=j
141134
while k>0
142135
if lnk.rowval[k]==i
143-
return lnk.nzval[k]
136+
return k,0
144137
end
138+
k0=k
145139
k=lnk.colptr[k]
146140
end
141+
return 0,k0
142+
end
143+
144+
145+
"""
146+
$(SIGNATURES)
147+
148+
Return value stored for entry or zero if not found
149+
"""
150+
function Base.getindex(lnk::SparseMatrixLNK{Tv,Ti},i::Integer, j::Integer) where {Tv,Ti<:Integer}
147151

148-
return zero(Tv)
152+
k,k0=findindex(lnk,i,j)
153+
if k==0
154+
return zero(Tv)
155+
else
156+
return lnk.nzval[k]
157+
end
149158
end
150159

151160

161+
function addentry!(lnk::SparseMatrixLNK{Tv,Ti},i,j,k,k0) where {Tv,Ti<:Integer}
162+
# increase number of entries
163+
lnk.nentries+=1
164+
if length(lnk.nzval)<lnk.nentries
165+
newsize=Int64(ceil(5.0*lnk.nentries/4.0))
166+
resize!(lnk.nzval,newsize)
167+
resize!(lnk.rowval,newsize)
168+
resize!(lnk.colptr,newsize)
169+
end
170+
171+
# Append entry if not found
172+
# lnk.nzval[lnk.nentries]=v
173+
lnk.rowval[lnk.nentries]=i
152174

175+
# Shift the end of the list
176+
lnk.colptr[lnk.nentries]=0
177+
lnk.colptr[k0]=lnk.nentries
178+
179+
# Update number of nonzero entries
180+
lnk.nnz+=1
181+
return lnk.nentries
182+
end
153183

154184
"""
155185
$(SIGNATURES)
@@ -173,40 +203,38 @@ function Base.setindex!(lnk::SparseMatrixLNK{Tv,Ti}, _v, _i::Integer, _j::Intege
173203
return lnk
174204
end
175205

176-
# Traverse list for existing entry
177-
k=j
178-
k0=j
179-
while k>0
180-
# Update value and return if entry has been found
181-
if lnk.rowval[k]==i
182-
lnk.nzval[k]=v
183-
return lnk
184-
end
185-
k0=k
186-
# Next element in the list
187-
k=lnk.colptr[k]
206+
k,k0=findindex(lnk,i,j)
207+
if k>0
208+
lnk.nzval[k]=v
209+
return lnk
188210
end
211+
k=addentry!(lnk,i,j,k,k0)
212+
lnk.nzval[k]=v
213+
return lnk
214+
end
189215

190-
# increase number of entries
191-
lnk.nentries+=1
192-
if length(lnk.nzval)<lnk.nentries
193-
newsize=Int64(ceil(5.0*lnk.nentries/4.0))
194-
resize!(lnk.nzval,newsize)
195-
resize!(lnk.rowval,newsize)
196-
resize!(lnk.colptr,newsize)
197-
end
198-
199-
# Append entry if not found
200-
lnk.nzval[lnk.nentries]=v
201-
lnk.rowval[lnk.nentries]=i
202216

203-
# Shift the end of the list
204-
lnk.colptr[lnk.nentries]=0
205-
lnk.colptr[k0]=lnk.nentries
217+
function update!(lnk::SparseMatrixLNK{Tv,Ti}, _v, _i::Integer, _j::Integer,op) where {Tv,Ti<:Integer}
218+
v = convert(Tv, _v)
219+
i = convert(Ti, _i)
220+
j = convert(Ti, _j)
206221

207-
# Update number of nonzero entries
208-
lnk.nnz+=1
209-
return lnk
222+
223+
# Set the first column entry if it was not yet set.
224+
if lnk.rowval[j]==0
225+
lnk.rowval[j]=i
226+
lnk.nzval[j]=op(zero(Tv),v)
227+
lnk.nnz+=1
228+
return lnk
229+
end
230+
k,k0=findindex(lnk,i,j)
231+
if k>0
232+
lnk.nzval[k]=op(lnk.nzval[k],v)
233+
return lnk
234+
end
235+
k=addentry!(lnk,i,j,k,k0)
236+
lnk.nzval[k]=op(lnk.nzval[k],v)
237+
lnk
210238
end
211239

212240

@@ -254,17 +282,16 @@ Base.isless(x::ColEntry{Tv, Ti},y::ColEntry{Tv, Ti}) where {Tv,Ti<:Integer} = (x
254282
"""
255283
$(SIGNATURES)
256284
257-
Add SparseMatrixCSC matrix and [`SparseMatrixLNK`](@ref) lnk.
285+
Add SparseMatrixCSC matrix and [`SparseMatrixLNK`](@ref) lnk, returning a SparseMatrixCSC
258286
"""
259-
function Base.:+(lnk::SparseMatrixLNK{Tv,Ti},csc::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti<:Integer}
287+
function Base.:+(lnk::SparseMatrixLNK{Tv,Ti},csc::SparseMatrixCSC{Tv,Ti})::SparseMatrixCSC{Tv,Ti} where {Tv,Ti<:Integer}
260288
@assert(csc.m==lnk.m)
261289
@assert(csc.n==lnk.n)
262-
290+
263291
xnnz=nnz(csc)+nnz(lnk)
264292
colptr=Vector{Ti}(undef,csc.n+1)
265293
rowval=Vector{Ti}(undef,xnnz)
266294
nzval=Vector{Tv}(undef,xnnz)
267-
268295
# Detect the maximum column length of lnk
269296
lnk_maxcol=0
270297
for j=1:csc.n
@@ -284,9 +311,12 @@ function Base.:+(lnk::SparseMatrixLNK{Tv,Ti},csc::SparseMatrixCSC{Tv,Ti}) where
284311

285312
inz=1 # counts the nonzero entries in the new matrix
286313

314+
l_lnk_col=zero(Ti)
315+
in_csc_col(jcsc)=(nnz(csc)>zero(Ti)) && (jcsc<csc.colptr[j+1])
316+
in_lnk_col(jlnk)=(jlnk<=l_lnk_col)
317+
287318
# loop over all columns
288319
for j=1:csc.n
289-
290320
# Copy extension entries into col and sort them
291321
k=j
292322
l_lnk_col=zero(Ti)
@@ -305,8 +335,6 @@ function Base.:+(lnk::SparseMatrixLNK{Tv,Ti},csc::SparseMatrixCSC{Tv,Ti}) where
305335
jlnk=one(Ti) # counts the entries in col
306336
jcsc=csc.colptr[j] # counts entries in csc
307337

308-
in_csc_col(jcsc)=(nnz(csc)>zero(Ti)) && (jcsc<csc.colptr[j+1])
309-
in_lnk_col(jlnk)=(jlnk<=l_lnk_col)
310338

311339
while true
312340
if in_csc_col(jcsc) && (in_lnk_col(jlnk) && csc.rowval[jcsc]<col[jlnk].rowval || !in_lnk_col(jlnk))

0 commit comments

Comments
 (0)