Skip to content

Commit 30fe88e

Browse files
committed
First steps
0 parents  commit 30fe88e

File tree

3 files changed

+250
-0
lines changed

3 files changed

+250
-0
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# ExtendableSparse
2+
3+
Sparse matrix class which allows cheaper assembly by using
4+
a different data structure for extension.
5+
6+
Still work in progress.
7+
8+
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
module ExtendableSparseTest
2+
3+
using ExtendableSparse
4+
using SparseArrays
5+
using Printf
6+
7+
import PyPlot
8+
9+
function randmatx!(A::AbstractSparseMatrix{Tv,Ti}, m::Ti, n::Ti,xnnz::Ti) where {Tv,Ti}
10+
for inz=1:xnnz
11+
i=rand((1:m))
12+
j=rand((1:n))
13+
a=1.0+rand(Float64)
14+
A[i,j]+=a
15+
end
16+
17+
end
18+
19+
function testmatx!(A::AbstractSparseMatrix{Tv,Ti}, m::Ti, n::Ti, xnnz::Ti) where {Tv,Ti}
20+
S=spzeros(m,n)
21+
for inz=1:xnnz
22+
i=rand((1:m))
23+
j=rand((1:n))
24+
a=1.0+rand(Float64)
25+
@printf("%d %d %e\n",i,j,a)
26+
S[i,j]+=a
27+
A[i,j]+=a
28+
@printf("nnz(S)=%d nnz(A)=%d\n",nnz(S),nnz(A))
29+
@assert(nnz(S)==nnz(A))
30+
end
31+
@printf("nnz(S)=%d nnz(A)=%d\n",nnz(S),nnz(A))
32+
@assert(nnz(S)==nnz(A))
33+
(I,J,V)=findnz(S)
34+
for inz=1:nnz(S)
35+
@assert(A[I[inz],J[inz]]==V[inz])
36+
end
37+
end
38+
39+
function check(;m=10,n=10,nnz=3)
40+
extmat1=SparseMatrixExtension{Float64,Int64}(m,n)
41+
testmatx!(extmat1,m,n,nnz)
42+
end
43+
44+
45+
function benchmark(;n=100,m=100,nnz=500, pyplot=false)
46+
47+
mat=spzeros(Float64,Int64,n,n)
48+
@time randmatx!(mat,m,n,nnz)
49+
50+
extmat=SparseMatrixExtension{Float64,Int64}(n,n)
51+
@time randmatx!(extmat,m,n,nnz)
52+
53+
54+
xextmat=ExtendableSparseMatrixCSC(n,n,spzeros(Float64,Int64,n,n),SparseMatrixExtension{Float64,Int64}(n,n))
55+
@time randmatx!(xextmat,m,n,nnz)
56+
57+
58+
59+
60+
if pyplot
61+
PyPlot.clf()
62+
PyPlot.spy(extmat,markersize=1)
63+
PyPlot.show()
64+
end
65+
# for i=1:n
66+
# println(rand(1:n))
67+
# end
68+
end
69+
70+
end

src/ExtendableSparse.jl

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
module ExtendableSparse
2+
using SparseArrays
3+
4+
mutable struct SparseMatrixExtension{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}
5+
m::Ti
6+
n::Ti
7+
nnz::Ti
8+
colptr::Vector{Ti}
9+
rowval::Vector{Ti}
10+
nzval::Vector{Tv}
11+
12+
function SparseMatrixExtension{Tv,Ti}(m,n) where {Tv,Ti<:Integer}
13+
colptr=zeros(Ti,n)
14+
rowval=zeros(Ti,n)
15+
nzval=zeros(Tv,n)
16+
new(m,n,0,zeros(Ti,n),zeros(Ti,n),zeros(Tv,n))
17+
end
18+
end
19+
20+
function Base.setindex!(E::SparseMatrixExtension{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer}
21+
v = convert(Tv, _v)
22+
i = convert(Ti, _i)
23+
j = convert(Ti, _j)
24+
25+
if !((1 <= i <= E.m) & (1 <= j <= E.n))
26+
throw(BoundsError(E, (i,j)))
27+
end
28+
29+
if iszero(v)
30+
return E
31+
end
32+
33+
if E.rowval[j]==0
34+
E.rowval[j]=i
35+
E.nzval[j]=v
36+
E.nnz+=1
37+
return E
38+
end
39+
40+
k=j
41+
k0=j
42+
while k>0
43+
if E.rowval[k]==i
44+
E.nzval[k]=v
45+
return E
46+
end
47+
k0=k
48+
k=E.colptr[k]
49+
end
50+
push!(E.nzval,v)
51+
push!(E.rowval,i)
52+
push!(E.colptr,-1)
53+
E.colptr[k0]=length(E.nzval)
54+
E.nnz+=1
55+
return E
56+
end
57+
58+
function Base.getindex(E::SparseMatrixExtension{Tv,Ti},i::Integer, j::Integer) where {Tv,Ti<:Integer}
59+
if !((1 <= i <= E.m) & (1 <= j <= E.n))
60+
throw(BoundsError(E, (i,j)))
61+
end
62+
63+
k=j
64+
while k>0
65+
if E.rowval[k]==i
66+
return E.nzval[k]
67+
end
68+
k=E.colptr[k]
69+
end
70+
return zero(Tv)
71+
end
72+
73+
Base.size(E::SparseMatrixExtension) = (E.m, E.n)
74+
SparseArrays.nnz(E::SparseMatrixExtension)=E.nnz
75+
76+
77+
78+
#####################################################################################################
79+
mutable struct ExtendableSparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}
80+
m::Int
81+
n::Int
82+
cscmatrix::SparseMatrixCSC{Tv,Ti}
83+
extmatrix::SparseMatrixExtension{Tv,Ti}
84+
end
85+
86+
function Base.setindex!(m::ExtendableSparseMatrixCSC, v, i::Integer, j::Integer)
87+
setindex!(m.cscmatrix,v,i,j)
88+
end
89+
90+
function Base.getindex(m::ExtendableSparseMatrixCSC,i::Integer, j::Integer)
91+
return getindex(m.cscmatrix,i,j)
92+
end
93+
94+
struct ColEntry{Tv,Ti<:Integer}
95+
j::Ti
96+
v::Tv
97+
end
98+
99+
isless(x::ColEntry,y::ColEntry)=(x.i<y.i)
100+
101+
function splice(E::ExtendableSparseMatrixCSC,S::SparseMatrixCSC)
102+
103+
104+
nnz=nnz(S)+nnz(E)
105+
colptr=Vector{Ti}(undef,S.m+1)
106+
rowval=Vector{Ti}(undef,nnz)
107+
nzval=Vector{Tv}(undef,nnz)
108+
109+
E_maxcol_ext=0
110+
S_maxcol=0
111+
for j=1:m
112+
lrow=0
113+
k=j
114+
while k>0
115+
lrow+=1
116+
k=E.colptr[k]
117+
end
118+
E_maxcol_ext=max(lrow,E_maxcol)
119+
S_maxcol=max(S.colptr[j+1]-S.colptr[j],S_maxcol)
120+
end
121+
122+
col=Vector{ColEntry}(undef,E_maxcol+S_Maxcol+10)
123+
maxcol=0
124+
125+
i=1
126+
for j=1:m
127+
# put extension entries into row and sort them
128+
k=j
129+
while k>0
130+
if colptr[k]>0
131+
lxcol+=1
132+
col[lxcol].j=colptr[k]
133+
col[lxcol].v=nzval[k]
134+
k=colptr[k]
135+
end
136+
end
137+
138+
sort!(col,lt=isless)
139+
# jointly sort old and mew entries into colptr
140+
141+
i0=i
142+
colptr[j]=i
143+
jcol=0
144+
k=S.colptr[j]
145+
while true
146+
if k<S.colptr[j+1] && icol>lxcol || S.colptr[k]<col[jcol].j
147+
rowval[i]=S.rowval[k]
148+
nzval[i]=S.nzval[k]
149+
k+=1
150+
i+=1
151+
continue
152+
end
153+
if jcol<lxcol
154+
rowval[i]=col[jcol].j
155+
nzval[i]=col[jcol].v
156+
jcol+=1
157+
i++
158+
continue
159+
end
160+
break
161+
end
162+
end
163+
maxrow=max(maxrow,i-i0)
164+
colptr[j]=i
165+
return SparseMatrixCSC(m,n,colptr,rowval,nzval)
166+
end
167+
168+
169+
170+
171+
export SparseMatrixExtension,ExtendableSparseMatrixCSC
172+
end # module

0 commit comments

Comments
 (0)