Skip to content

Commit c89e9b6

Browse files
committed
Interface to LinearSolve
1 parent f5b5e80 commit c89e9b6

File tree

8 files changed

+99
-10
lines changed

8 files changed

+99
-10
lines changed

Project.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
name = "ExtendableSparse"
22
uuid = "95c220a8-a1cf-11e9-0c77-dbfce5f500b3"
33
authors = ["Juergen Fuhrmann <juergen.fuhrmann@wias-berlin.de>"]
4-
version = "0.6.8"
4+
version = "0.7.0"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
8+
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
911
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1012
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
13+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
14+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1115
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1216
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
1317
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18+
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1419

1520
[compat]
1621
DocStringExtensions = "0.8.0,0.9"
22+
LinearSolve = "1.23"
1723
Requires = "^1.1.3"
18-
julia = "^1.5"
24+
Setfield = "0.7, 0.8, 1"
25+
SciMLBase="1.49"
26+
UnPack = "1"
27+
julia = "^1.6"
1928

2029
[extras]
2130
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

src/ExtendableSparse.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,28 @@ module ExtendableSparse
22
using SparseArrays
33
using LinearAlgebra
44
using SuiteSparse
5+
using Setfield: @set!
6+
using UnPack: @unpack
7+
import LinearSolve,SciMLBase
8+
59
using Requires
610

711
using DocStringExtensions
812

913
import SparseArrays: rowvals, getcolptr, nonzeros
10-
import Base: copy
14+
1115

1216
include("sparsematrixcsc.jl")
1317
include("sparsematrixlnk.jl")
1418
include("extendable.jl")
1519

1620
export SparseMatrixLNK,ExtendableSparseMatrix,flush!,nnz, updateindex!, rawupdateindex!, colptrs
1721

22+
23+
include("linearsolve.jl")
24+
25+
26+
1827
include("factorizations.jl")
1928
export JacobiPreconditioner, ILU0Preconditioner, ParallelJacobiPreconditioner, ParallelILU0Preconditioner, reorderlinsys
2029
export AbstractFactorization,LUFactorization, CholeskyFactorization

src/extendable.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ Create similar but emtpy extendableSparseMatrix
6262
"""
6363
Base.similar(m::ExtendableSparseMatrix{Tv,Ti}) where {Tv,Ti}=ExtendableSparseMatrix{Tv,Ti}(size(m)...)
6464

65+
Base.similar(m::ExtendableSparseMatrix{Tv,Ti},::Type{T}) where {Tv,Ti,T}=ExtendableSparseMatrix{T,Ti}(size(m)...)
6566

6667
"""
6768
$(SIGNATURES)
@@ -394,10 +395,11 @@ function SparseArrays.dropzeros!(ext::ExtendableSparseMatrix)
394395
end
395396

396397

397-
function copy(S::ExtendableSparseMatrix)
398+
function Base.copy(S::ExtendableSparseMatrix)
398399
if isnothing(S.lnkmatrix)
399400
ExtendableSparseMatrix(copy(S.cscmatrix), nothing, S.phash)
400401
else
401402
ExtendableSparseMatrix(copy(S.cscmatrix), copy(S.lnkmatrix), S.phash)
402403
end
403-
end
404+
end
405+

src/linearsolve.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""
2+
LinearSolve.LinearProblem(A::ExtendableSparseMatrix,x,p=SciMLBase.NullParameters();u0=nothing,kwargs...)
3+
4+
Create linear problem from ExtendableSparseMatrix. This uses the internal SparseMatrixCSC, and thus allows to access
5+
the functionality of [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl)
6+
"""
7+
function LinearSolve.LinearProblem(A::ExtendableSparseMatrix,x,p=SciMLBase.NullParameters();u0=nothing,kwargs...)
8+
flush!(A)
9+
LinearSolve.LinearProblem{false}(A.cscmatrix,x,p;u0,kwargs...)
10+
end
11+
12+
"""
13+
LinearSolve.set_A(cache::LinearSolve.LinearCache, Aext::ExtendableSparseMatrix)
14+
15+
Update the linear solve cache from an ExtendableSparseMatrix. Note that this update allows
16+
to take into account changes of the matrix pattern.
17+
"""
18+
function LinearSolve.set_A(cache::LinearSolve.LinearCache, Aext::ExtendableSparseMatrix)
19+
flush!(Aext)
20+
21+
@unpack alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose = cache
22+
23+
if !pattern_equal(Aext.cscmatrix,A)
24+
cacheval = LinearSolve.init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
25+
@set! cache.cacheval=cacheval
26+
@set! cache.isfresh = false
27+
end
28+
@set! cache.A = Aext.cscmatrix
29+
@set! cache.isfresh = true
30+
return cache
31+
end

src/sparsematrixcsc.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,11 @@ Hash of csc matrix pattern.
7575
phash(csc::SparseMatrixCSC)=hash((hash(csc.colptr),hash(csc.rowval)))
7676
# probably no good idea to add two hashes, so we hash them together.
7777

78+
79+
"""
80+
pattern_equal(a::SparseMatrixCSC,b::SparseMatrixCSC)
81+
82+
Check if sparsity patterns of two SparseMatrixCSC objects are equal.
83+
This is generally faster than comparing hashes.
84+
"""
85+
pattern_equal(a::SparseMatrixCSC,b::SparseMatrixCSC) = a.colptr==b.colptr && a.rowval==b.rowval

src/sparsematrixlnk.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,4 +415,4 @@ rowvals(S::SparseMatrixLNK) = getfield(S, :rowval)
415415
getcolptr(S::SparseMatrixLNK) = getfield(S, :colptr)
416416
nonzeros(S::SparseMatrixLNK) = getfield(S, :nzval)
417417

418-
copy(S::SparseMatrixLNK) = SparseMatrixLNK(size(S, 1), size(S, 2), S.nnz, S.nentries, copy(getcolptr(S)), copy(rowvals(S)), copy(nonzeros(S)))
418+
Base.copy(S::SparseMatrixLNK) = SparseMatrixLNK(size(S, 1), size(S, 2), S.nnz, S.nentries, copy(getcolptr(S)), copy(rowvals(S)), copy(nonzeros(S)))

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
55
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
78
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
89
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
910
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1112

1213
[compat]
13-
AlgebraicMultigrid = "^0.4"
14+
AlgebraicMultigrid = "0.4,0.5"
1415
IncompleteLU = "^0.2"
1516
Pardiso = "^0.5.1"

test/runtests.jl

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ using AlgebraicMultigrid
1010
using IncompleteLU
1111
using IterativeSolvers
1212

13+
using LinearSolve
14+
1315
##############################################################
1416
@testset "Constructors" begin
1517
function test_constructors()
@@ -320,7 +322,7 @@ end
320322
@test test_hermitian(300,:L)
321323
end
322324

323-
function test_lu1(k,l,m; lufac=LUFactorization())
325+
function test_lu1(k,l,m; lufac=ExtendableSparse.LUFactorization())
324326
Acsc=fdrand(k,l,m,rand=()->1,matrixtype=SparseMatrixCSC)
325327
b=rand(k*l*m)
326328
LUcsc=lu(Acsc)
@@ -342,7 +344,7 @@ function test_lu1(k,l,m; lufac=LUFactorization())
342344
x1cscx1ext && x2csc x2ext
343345
end
344346

345-
function test_lu2(k,l,m;lufac=LUFactorization())
347+
function test_lu2(k,l,m;lufac=ExtendableSparse.LUFactorization())
346348
Aext=fdrand(k,l,m,rand=()->1,matrixtype=ExtendableSparseMatrix)
347349
b=rand(k*l*m)
348350
lu!(lufac,Aext)
@@ -357,7 +359,7 @@ function test_lu2(k,l,m;lufac=LUFactorization())
357359
end
358360

359361

360-
@testset "LUFactorization" begin
362+
@testset "ExtendableSparse.LUFactorization" begin
361363
@test test_lu1(10,10,10)
362364
@test test_lu1(25,40,1)
363365
@test test_lu1(1000,1,1)
@@ -432,3 +434,30 @@ end
432434
@test test_parilu0(1000)
433435
end
434436

437+
438+
function test_linearsolve(n)
439+
440+
A=fdrand(n,1,1, matrixtype=ExtendableSparseMatrix)
441+
b=rand(n)
442+
c=A\b
443+
@test c LinearSolve.solve(LinearProblem(A,b)).u
444+
445+
A=fdrand(n,n,1, matrixtype=ExtendableSparseMatrix)
446+
b=rand(n*n)
447+
c=A\b
448+
@test c LinearSolve.solve(LinearProblem(A,b)).u
449+
450+
@test c LinearSolve.solve(LinearProblem(A,b),UMFPACKFactorization()).u
451+
@test c LinearSolve.solve(LinearProblem(A,b),KLUFactorization()).u
452+
@test c LinearSolve.solve(LinearProblem(A,b),IterativeSolversJL_CG(),Pl=ILU0Preconditioner(A)).u
453+
@test c LinearSolve.solve(LinearProblem(A,b),IterativeSolversJL_CG(),Pl=JacobiPreconditioner(A)).u
454+
@test c LinearSolve.solve(LinearProblem(A,b),IterativeSolversJL_CG(),Pl=ParallelJacobiPreconditioner(A)).u
455+
@test c LinearSolve.solve(LinearProblem(A,b),IterativeSolversJL_CG(),Pl=ILUTPreconditioner(A)).u
456+
@test c LinearSolve.solve(LinearProblem(A,b),IterativeSolversJL_CG(),Pl=AMGPreconditioner(A)).u
457+
458+
459+
end
460+
461+
@testset "LinearSolve" begin
462+
test_linearsolve(20)
463+
end

0 commit comments

Comments
 (0)