Skip to content

Commit 91681aa

Browse files
committed
re-organized+added tests
bugfix in getindex
1 parent 6be6f6a commit 91681aa

File tree

6 files changed

+175
-102
lines changed

6 files changed

+175
-102
lines changed

examples/ExtendableSparseMatrixTest.jl

Lines changed: 0 additions & 95 deletions
This file was deleted.

src/ExtendableSparse.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using LinearAlgebra
55

66
include("extension.jl")
77
include("extendable.jl")
8+
include("sprand.jl")
89

9-
export SparseMatrixExtension,ExtendableSparseMatrix,flush!,nnz
10+
export SparseMatrixExtension,ExtendableSparseMatrix,flush!,nnz,sprand!
1011

1112
export xcolptrs,colptrs
1213
end # module

src/extendable.jl

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ ExtendableSparseMatrix(::Type{Tv}, ::Type{Ti}, m::Integer, n::Integer) where {Tv
3232

3333

3434

35+
36+
37+
3538
"""
3639
$(TYPEDSIGNATURES)
3740
@@ -43,18 +46,21 @@ function findindex(S::SparseMatrixCSC{T}, i::Integer, j::Integer) where T
4346
r1 = Int(S.colptr[j])
4447
r2 = Int(S.colptr[j+1]-1)
4548
if r1>r2
46-
return 0
49+
return zero(T)
4750
end
4851

4952
# See sparsematrix.jl
5053
r1 = searchsortedfirst(S.rowval, i, r1, r2, Base.Forward)
51-
if (r1>length(S.rowval) ||S.rowval[r1] != i)
52-
return 0
54+
if (r1>r2 ||S.rowval[r1] != i)
55+
return zero(T)
5356
end
5457
return r1
5558
end
5659

5760

61+
62+
63+
5864
"""
5965
$(TYPEDSIGNATURES)
6066
@@ -70,6 +76,10 @@ function Base.setindex!(M::ExtendableSparseMatrix, v, i::Integer, j::Integer)
7076
end
7177
end
7278

79+
80+
81+
82+
7383
"""
7484
$(TYPEDSIGNATURES)
7585
@@ -86,6 +96,10 @@ function Base.getindex(M::ExtendableSparseMatrix,i::Integer, j::Integer)
8696
end
8797

8898

99+
100+
101+
102+
89103
"""
90104
$(TYPEDSIGNATURES)
91105
@@ -94,6 +108,10 @@ Matrix size.
94108
Base.size(E::ExtendableSparseMatrix) = (E.cscmatrix.m, E.cscmatrix.n)
95109

96110

111+
112+
113+
114+
97115
"""
98116
$(TYPEDSIGNATURES)
99117
@@ -103,6 +121,11 @@ SparseArrays.nnz(E::ExtendableSparseMatrix)=(nnz(E.cscmatrix)+nnz(E.extmatrix))
103121

104122

105123

124+
125+
126+
127+
128+
106129
# Struct holding pair of value and row
107130
# number, for sorting
108131
mutable struct ColEntry{Tv,Ti<:Integer}
@@ -114,7 +137,6 @@ end
114137
Base.isless(x::ColEntry{Tv, Ti},y::ColEntry{Tv, Ti}) where {Tv,Ti<:Integer} = (x.i<y.i)
115138

116139

117-
118140
function _splice(E::SparseMatrixExtension{Tv,Ti},S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti<:Integer}
119141
# Create new CSC matrix with sorted entries from CSC matrix S and matrix extension E.
120142
#
@@ -130,7 +152,7 @@ function _splice(E::SparseMatrixExtension{Tv,Ti},S::SparseMatrixCSC{Tv,Ti}) wher
130152
rowval=Vector{Ti}(undef,xnnz)
131153
nzval=Vector{Tv}(undef,xnnz)
132154

133-
# Detect the maximum column lenght of E
155+
# Detect the maximum column length of E
134156
E_maxcol=0
135157
for j=1:S.n
136158
lcol=0
@@ -143,7 +165,7 @@ function _splice(E::SparseMatrixExtension{Tv,Ti},S::SparseMatrixCSC{Tv,Ti}) wher
143165
end
144166

145167
# pre-allocate column
146-
col=[ColEntry{Tv,Ti}(0,0) for i=1:E_maxcol+1]
168+
col=[ColEntry{Tv,Ti}(0,0) for i=1:E_maxcol]
147169

148170

149171

@@ -199,6 +221,9 @@ function _splice(E::SparseMatrixExtension{Tv,Ti},S::SparseMatrixCSC{Tv,Ti}) wher
199221
end
200222

201223

224+
225+
226+
202227
"""
203228
$(TYPEDSIGNATURES)
204229
@@ -215,6 +240,9 @@ end
215240

216241

217242

243+
244+
245+
218246
"""
219247
$(TYPEDSIGNATURES)
220248
@@ -225,6 +253,9 @@ function SparseArrays.nonzeros(E::ExtendableSparseMatrix)
225253
return nonzeros(E.cscmatrix)
226254
end
227255

256+
257+
258+
228259
"""
229260
$(TYPEDSIGNATURES)
230261
@@ -236,6 +267,9 @@ function SparseArrays.rowvals(E::ExtendableSparseMatrix)
236267
end
237268

238269

270+
271+
272+
239273
"""
240274
$(TYPEDSIGNATURES)
241275
@@ -246,6 +280,9 @@ function xcolptrs(E::ExtendableSparseMatrix)
246280
return E.cscmatrix.colptr
247281
end
248282

283+
284+
285+
249286
"""
250287
$(TYPEDSIGNATURES)
251288
@@ -257,6 +294,16 @@ function colptrs(E::ExtendableSparseMatrix)
257294
end
258295

259296

297+
"""
298+
$(TYPEDSIGNATURES)
299+
300+
Flush and delegate to cscmatrix.
301+
"""
302+
function SparseArrays.findnz(E::ExtendableSparseMatrix)
303+
flush!(E)
304+
return findnz(E.cscmatrix)
305+
end
306+
260307

261308
"""
262309
$(TYPEDSIGNATURES)
@@ -270,6 +317,8 @@ function LinearAlgebra.lu(E::ExtendableSparseMatrix)
270317
end
271318

272319

320+
321+
273322
"""
274323
$(TYPEDSIGNATURES)
275324

src/sprand.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
$(TYPEDSIGNATURES)
3+
4+
Fill empty sparse matrix A with random nonzero elements from interval [1,2]
5+
using incremental assembly.
6+
7+
"""
8+
function sprand!(A::AbstractSparseMatrix{Tv,Ti},xnnz::Int) where {Tv,Ti}
9+
m,n=size(A)
10+
for i=1:xnnz
11+
i=rand((1:m))
12+
j=rand((1:n))
13+
a=1.0+rand(Tv)
14+
A[i,j]+=a
15+
end
16+
end
17+

test/ExtendableSparseTest.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
module ExtendableSparseTest
2+
using ExtendableSparse
3+
using SparseArrays
4+
using Printf
5+
6+
7+
8+
function randtest!(A::AbstractSparseMatrix{Tv,Ti},xnnz::Int,nsplice::Int) where {Tv,Ti}
9+
m,n=size(A)
10+
S=spzeros(m,n)
11+
for isplice=1:nsplice
12+
for inz=1:xnnz
13+
i=rand((1:m))
14+
j=rand((1:n))
15+
a=1.0+rand(Float64)
16+
S[i,j]+=a
17+
A[i,j]+=a
18+
@assert(nnz(S)==nnz(A))
19+
end
20+
flush!(A)
21+
for j=1:n
22+
@assert(issorted(A.cscmatrix.rowval[A.cscmatrix.colptr[j]:A.cscmatrix.colptr[j+1]-1]))
23+
end
24+
@assert(nnz(S)==nnz(A))
25+
26+
(I,J,V)=findnz(S)
27+
for inz=1:nnz(S)
28+
@assert(A[I[inz],J[inz]]==V[inz])
29+
end
30+
31+
(I,J,V)=findnz(A)
32+
for inz=1:nnz(A)
33+
@assert(S[I[inz],J[inz]]==V[inz])
34+
end
35+
36+
end
37+
return true
38+
end
39+
40+
41+
42+
function check(;m=1000,n=1000,nnz=5000,nsplice=1)
43+
mat=ExtendableSparseMatrix(Float64,Int64,m,n)
44+
return randtest!(mat,nnz,nsplice)
45+
end
46+
47+
48+
49+
function benchmark(;n=10000,m=10000,nnz=50000)
50+
51+
println("SparseMatrixCSC:")
52+
mat=spzeros(Float64,Int64,m,n)
53+
@time sprand!(mat,nnz)
54+
55+
println("SparseMatrixExtension:")
56+
extmat=SparseMatrixExtension(Float64,Int64,m,n)
57+
@time sprand!(extmat,nnz)
58+
59+
println("ExtendableSparseMatrix:")
60+
xextmat=ExtendableSparseMatrix(Float64,Int64,m,n)
61+
@time begin
62+
sprand!(xextmat,nnz)
63+
@inbounds flush!(xextmat)
64+
end
65+
return
66+
b
67+
end
68+
69+
70+
end

0 commit comments

Comments
 (0)