From 521554a36f00f6d7adb266308edfafea1abc8413 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 14 Oct 2025 04:05:48 -0400 Subject: [PATCH 1/5] Try integrating with the GPUArrays sparse migration --- Project.toml | 3 + lib/cusparse/CUSPARSE.jl | 2 - lib/cusparse/array.jl | 85 +++-- lib/cusparse/broadcast.jl | 669 -------------------------------------- lib/cusparse/device.jl | 140 +------- lib/cusparse/reduce.jl | 122 ------- src/CUDAKernels.jl | 10 +- test/Project.toml | 3 +- 8 files changed, 70 insertions(+), 964 deletions(-) delete mode 100644 lib/cusparse/broadcast.jl delete mode 100644 lib/cusparse/reduce.jl diff --git a/Project.toml b/Project.toml index c0fb240200..94e51a4e97 100644 --- a/Project.toml +++ b/Project.toml @@ -96,3 +96,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[sources] +GPUArrays = {url = "https://github.com/JuliaGPU/GPUArrays.jl", rev = "ksh/sparse"} diff --git a/lib/cusparse/CUSPARSE.jl b/lib/cusparse/CUSPARSE.jl index e497710bea..ae1bdf5fb8 100644 --- a/lib/cusparse/CUSPARSE.jl +++ b/lib/cusparse/CUSPARSE.jl @@ -49,8 +49,6 @@ include("interfaces.jl") # native functionality include("device.jl") -include("broadcast.jl") -include("reduce.jl") include("batched.jl") diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl index 55693e38b4..94d32f084f 100644 --- a/lib/cusparse/array.jl +++ b/lib/cusparse/array.jl @@ -8,13 +8,9 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCO CuSparseVecOrMat using LinearAlgebra: BlasFloat -using SparseArrays: nonzeroinds, dimlub - -abstract type AbstractCuSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end -const AbstractCuSparseVector{Tv, Ti} = AbstractCuSparseArray{Tv, Ti, 1} -const AbstractCuSparseMatrix{Tv, Ti} = AbstractCuSparseArray{Tv, Ti, 2} - -Base.convert(T::Type{<:AbstractCuSparseArray}, m::AbstractArray) = m isa T ? m : T(m) +using SparseArrays +abstract type AbstractCuSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, 1} end +abstract type AbstractCuSparseMatrix{Tv, Ti} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, 2} end mutable struct CuSparseVector{Tv, Ti} <: AbstractCuSparseVector{Tv, Ti} iPtr::CuVector{Ti} @@ -34,7 +30,7 @@ function CUDA.unsafe_free!(xs::CuSparseVector) return end -mutable struct CuSparseMatrixCSC{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti} +mutable struct CuSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti} colPtr::CuVector{Ti} rowVal::CuVector{Ti} nzVal::CuVector{Tv} @@ -47,6 +43,11 @@ mutable struct CuSparseMatrixCSC{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti} end end +SparseArrays.rowvals(g::T) where {T<:CuSparseVector} = nonzeroinds(g) + +SparseArrays.rowvals(g::CuSparseMatrixCSC) = g.rowVal +SparseArrays.getcolptr(S::CuSparseMatrixCSC) = S.colPtr + CuSparseMatrixCSC(A::CuSparseMatrixCSC) = A function CUDA.unsafe_free!(xs::CuSparseMatrixCSC) @@ -69,7 +70,7 @@ GPU. !!! compat "CUDA 11" Support of indices type rather than `Cint` (`Int32`) requires at least CUDA 11. """ -mutable struct CuSparseMatrixCSR{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti} +mutable struct CuSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti} rowPtr::CuVector{Ti} colVal::CuVector{Ti} nzVal::CuVector{Tv} @@ -91,6 +92,22 @@ function CUDA.unsafe_free!(xs::CuSparseMatrixCSR) return end +GPUArrays._sparse_array_type(sa::CuSparseMatrixCSC) = CuSparseMatrixCSC +GPUArrays._sparse_array_type(::Type{<:CuSparseMatrixCSC}) = CuSparseMatrixCSC +GPUArrays._sparse_array_type(sa::CuSparseMatrixCSR) = CuSparseMatrixCSR +GPUArrays._sparse_array_type(::Type{<:CuSparseMatrixCSR}) = CuSparseMatrixCSR +GPUArrays._sparse_array_type(sa::CuSparseVector) = CuSparseVector +GPUArrays._sparse_array_type(::Type{<:CuSparseVector}) = CuSparseVector +GPUArrays._dense_array_type(sa::CuSparseVector) = CuArray +GPUArrays._dense_array_type(::Type{<:CuSparseVector}) = CuArray +GPUArrays._dense_array_type(sa::CuSparseMatrixCSC) = CuArray +GPUArrays._dense_array_type(::Type{<:CuSparseMatrixCSC}) = CuArray +GPUArrays._dense_array_type(sa::CuSparseMatrixCSR) = CuArray +GPUArrays._dense_array_type(::Type{<:CuSparseMatrixCSR}) = CuArray + +GPUArrays._csc_type(sa::CuSparseMatrixCSR) = CuSparseMatrixCSC +GPUArrays._csr_type(sa::CuSparseMatrixCSC) = CuSparseMatrixCSR + """ Container to hold sparse matrices in block compressed sparse row (BSR) format on the GPU. BSR format is also used in Intel MKL, and is suited to matrices that are @@ -142,7 +159,7 @@ end CuSparseMatrixCOO(A::CuSparseMatrixCOO) = A -mutable struct CuSparseArrayCSR{Tv, Ti, N} <: AbstractCuSparseArray{Tv, Ti, N} +mutable struct CuSparseArrayCSR{Tv, Ti, N} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, N} rowPtr::CuArray{Ti} colVal::CuArray{Ti} nzVal::CuArray{Tv} @@ -308,15 +325,6 @@ end ## sparse array interface -SparseArrays.nnz(g::AbstractCuSparseArray) = g.nnz -SparseArrays.nonzeros(g::AbstractCuSparseArray) = g.nzVal - -SparseArrays.nonzeroinds(g::AbstractCuSparseVector) = g.iPtr -SparseArrays.rowvals(g::AbstractCuSparseVector) = nonzeroinds(g) - -SparseArrays.rowvals(g::CuSparseMatrixCSC) = g.rowVal -SparseArrays.getcolptr(S::CuSparseMatrixCSC) = S.colPtr - function SparseArrays.findnz(S::MT) where {MT <: AbstractCuSparseMatrix} S2 = CuSparseMatrixCOO(S) I = S2.rowInd @@ -570,8 +578,8 @@ CuSparseMatrixCSC(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuS CuSparseMatrixCOO(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCOO(_spadjoint(parent(x))) # gpu to cpu -SparseArrays.SparseVector(x::CuSparseVector) = SparseVector(length(x), Array(nonzeroinds(x)), Array(nonzeros(x))) -SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(rowvals(x)), Array(nonzeros(x))) +SparseArrays.SparseVector(x::CuSparseVector) = SparseVector(length(x), Array(SparseArrays.nonzeroinds(x)), Array(SparseArrays.nonzeros(x))) +SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(SparseArrays.rowvals(x)), Array(SparseArrays.nonzeros(x))) SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCSR) = SparseMatrixCSC(CuSparseMatrixCSC(x)) # no direct conversion (gpu_CSR -> gpu_CSC -> cpu_CSC) SparseArrays.SparseMatrixCSC(x::CuSparseMatrixBSR) = SparseMatrixCSC(CuSparseMatrixCSR(x)) # no direct conversion (gpu_BSR -> gpu_CSR -> gpu_CSC -> cpu_CSC) SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCOO) = SparseMatrixCSC(CuSparseMatrixCSC(x)) # no direct conversion (gpu_COO -> gpu_CSC -> cpu_CSC) @@ -729,16 +737,31 @@ end # interop with device arrays +function GPUArrays.GPUSparseDeviceVector(iPtr::CuDeviceVector{Ti, A}, + nzVal::CuDeviceVector{Tv, A}, + len::Int, + nnz::Ti) where {Ti, Tv, A} + GPUArrays.GPUSparseDeviceVector{Tv, Ti, CuDeviceVector{Ti, A}, CuDeviceVector{Tv, A}, A}(iPtr, nzVal, len, nnz) +end + function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseVector) - return CuSparseDeviceVector( + return GPUArrays.GPUSparseDeviceVector( adapt(to, x.iPtr), adapt(to, x.nzVal), length(x), x.nnz ) end +function GPUArrays.GPUSparseDeviceMatrixCSR(rowPtr::CuDeviceVector{Ti, A}, + colVal::CuDeviceVector{Ti, A}, + nzVal::CuDeviceVector{Tv, A}, + dims::NTuple{2, Int}, + nnz::Ti) where {Ti, Tv, A} + GPUArrays.GPUSparseDeviceMatrixCSR{Tv, Ti, CuDeviceVector{Ti, A}, CuDeviceVector{Tv, A}, A}(rowPtr, colVal, nzVal, dims, nnz) +end + function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSR) - return CuSparseDeviceMatrixCSR( + return GPUArrays.GPUSparseDeviceMatrixCSR( adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), @@ -746,8 +769,16 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSR) ) end +function GPUArrays.GPUSparseDeviceMatrixCSC(colPtr::CuDeviceVector{Ti, A}, + rowVal::CuDeviceVector{Ti, A}, + nzVal::CuDeviceVector{Tv, A}, + dims::NTuple{2, Int}, + nnz::Ti) where {Ti, Tv, A} + GPUArrays.GPUSparseDeviceMatrixCSC{Tv, Ti, CuDeviceVector{Ti, A}, CuDeviceVector{Tv, A}, A}(colPtr, rowVal, nzVal, dims, nnz) +end + function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSC) - return CuSparseDeviceMatrixCSC( + return GPUArrays.GPUSparseDeviceMatrixCSC( adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), @@ -756,7 +787,7 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSC) end function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixBSR) - return CuSparseDeviceMatrixBSR( + return GPUArrays.GPUSparseDeviceMatrixBSR( adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), @@ -766,7 +797,7 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixBSR) end function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCOO) - return CuSparseDeviceMatrixCOO( + return GPUArrays.GPUSparseDeviceMatrixCOO( adapt(to, x.rowInd), adapt(to, x.colInd), adapt(to, x.nzVal), @@ -775,7 +806,7 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCOO) end function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseArrayCSR) - return CuSparseDeviceArrayCSR( + return GPUArrays.GPUSparseDeviceArrayCSR( adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl deleted file mode 100644 index dff155c7bb..0000000000 --- a/lib/cusparse/broadcast.jl +++ /dev/null @@ -1,669 +0,0 @@ -using Base.Broadcast: Broadcasted - -using CUDA: CuArrayStyle - -# TODO: support more types (SparseVector, SparseMatrixCSC, COO, BSR) - - -## sparse broadcast style - -# broadcast container type promotion for combinations of sparse arrays and other types -struct CuSparseVecStyle <: Broadcast.AbstractArrayStyle{1} end -struct CuSparseMatStyle <: Broadcast.AbstractArrayStyle{2} end -Broadcast.BroadcastStyle(::Type{<:CuSparseVector}) = CuSparseVecStyle() -Broadcast.BroadcastStyle(::Type{<:CuSparseMatrix}) = CuSparseMatStyle() -const SPVM = Union{CuSparseVecStyle,CuSparseMatStyle} - -# CuSparseVecStyle handles 0-1 dimensions, CuSparseMatStyle 0-2 dimensions. -# CuSparseVecStyle promotes to CuSparseMatStyle for 2 dimensions. -# Fall back to DefaultArrayStyle for higher dimensionality. -CuSparseVecStyle(::Val{0}) = CuSparseVecStyle() -CuSparseVecStyle(::Val{1}) = CuSparseVecStyle() -CuSparseVecStyle(::Val{2}) = CuSparseMatStyle() -CuSparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() -CuSparseMatStyle(::Val{0}) = CuSparseMatStyle() -CuSparseMatStyle(::Val{1}) = CuSparseMatStyle() -CuSparseMatStyle(::Val{2}) = CuSparseMatStyle() -CuSparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() - -Broadcast.BroadcastStyle(::CuSparseVecStyle, ::CuArrayStyle{1}) = CuSparseVecStyle() -Broadcast.BroadcastStyle(::CuSparseVecStyle, ::CuArrayStyle{2}) = CuSparseMatStyle() -Broadcast.BroadcastStyle(::CuSparseMatStyle, ::CuArrayStyle{2}) = CuSparseMatStyle() - -# don't wrap sparse arrays with Extruded -Broadcast.extrude(x::CuSparseVecOrMat) = x - - -## detection of zero-preserving functions - -# modified from SparseArrays.jl - -# capturescalars takes a function (f) and a tuple of broadcast arguments, and returns a -# partially-evaluated function and a reduced argument tuple where all scalar operations have -# been applied already. -@inline function capturescalars(f, mixedargs) - let (passedsrcargstup, makeargs) = _capturescalars(mixedargs...) - parevalf = (passed...) -> f(makeargs(passed...)...) - return (parevalf, passedsrcargstup) - end -end -# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates -@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} = - capturescalars((args...)->f(T, args...), Base.tail(mixedargs)) -@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Ref{Type{S}}, Vararg{Any}}) where {T, S} = - # This definition is identical to the one above and necessary only for - # avoiding method ambiguity. - capturescalars((args...)->f(T, args...), Base.tail(mixedargs)) -@inline capturescalars(f, mixedargs::Tuple{CuSparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} = - capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...)) -@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{<:Any,0}}, Ref{Type{T}}, Vararg{Any}}) where {T} = - capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs))) - -scalararg(::Number) = true -scalararg(::Any) = false -scalarwrappedarg(::Union{AbstractArray{<:Any,0},Ref}) = true -scalarwrappedarg(::Any) = false - -@inline function _capturescalars() - return (), () -> () -end -@inline function _capturescalars(arg, mixedargs...) - let (rest, f) = _capturescalars(mixedargs...) - if scalararg(arg) - return rest, @inline function(tail...) - (arg, f(tail...)...) - end # add back scalararg after (in makeargs) - elseif scalarwrappedarg(arg) - return rest, @inline function(tail...) - (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple - end # unwrap and add back scalararg after (in makeargs) - else - return (arg, rest...), @inline function(head, tail...) - (head, f(tail...)...) - end # pass-through to broadcast - end - end -end -@inline function _capturescalars(arg) - # this definition is just an optimization (to bottom out the recursion slightly sooner) - if scalararg(arg) - return (), () -> (arg,) # add scalararg - elseif scalarwrappedarg(arg) - return (), () -> (arg[],) # unwrap - else - return (arg,), (head,) -> (head,) # pass-through - end -end - -@inline _iszero(x) = x == 0 -@inline _iszero(x::Number) = Base.iszero(x) -@inline _iszero(x::AbstractArray) = Base.iszero(x) -@inline _zeros_eltypes(A) = (zero(eltype(A)),) -@inline _zeros_eltypes(A, Bs...) = (zero(eltype(A)), _zeros_eltypes(Bs...)...) - - -## COV_EXCL_START -## iteration helpers -""" - CSRIterator{Ti}(row, args...) - -A GPU-compatible iterator for accessing the elements of a single row `row` of several CSR -matrices `args` in one go. The row should be in-bounds for every sparse argument. Each -iteration returns a 2-element tuple: The current column, and each arguments' pointer index -(or 0 if that input didn't have an element at that column). The pointers can then be used to -access the elements themselves. - -For convenience, this iterator can be passed non-sparse arguments as well, which will be -ignored (with the returned `col`/`ptr` values set to 0). -""" -struct CSRIterator{Ti,N,ATs} - row::Ti - col_ends::NTuple{N, Ti} - args::ATs -end - -function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N} - # check that `row` is valid for all arguments - @boundscheck begin - ntuple(Val(N)) do i - arg = @inbounds args[i] - arg isa CuSparseDeviceMatrixCSR && checkbounds(arg, row, 1) - end - end - - col_ends = ntuple(Val(N)) do i - arg = @inbounds args[i] - if arg isa CuSparseDeviceMatrixCSR - @inbounds(arg.rowPtr[row+1i32]) - else - zero(Ti) - end - end - - CSRIterator{Ti, N, typeof(args)}(row, col_ends, args) -end - -@inline function Base.iterate(iter::CSRIterator{Ti,N}, state=nothing) where {Ti,N} - # helper function to get the column of a sparse array at a specific pointer - @inline function get_col(i, ptr) - arg = @inbounds iter.args[i] - if arg isa CuSparseDeviceMatrixCSR - col_end = @inbounds iter.col_ends[i] - if ptr < col_end - return @inbounds arg.colVal[ptr] % Ti - end - end - typemax(Ti) - end - - # initialize the state - # - ptr: the current index into the colVal/nzVal arrays - # - col: the current column index (cached so that we don't have to re-read each time) - state = something(state, - ntuple(Val(N)) do i - arg = @inbounds iter.args[i] - if arg isa CuSparseDeviceMatrixCSR - ptr = @inbounds iter.args[i].rowPtr[iter.row] % Ti - col = @inbounds get_col(i, ptr) - else - ptr = typemax(Ti) - col = typemax(Ti) - end - (; ptr, col) - end - ) - - # determine the column we're currently processing - cols = ntuple(i -> @inbounds(state[i].col), Val(N)) - cur_col = min(cols...) - cur_col == typemax(Ti) && return - - # fetch the pointers (we don't look up the values, as the caller might want to index - # the sparse array directly, e.g., to mutate it). we don't return `ptrs` from the state - # directly, but first convert the `typemax(Ti)` to a more convenient zero value. - # NOTE: these values may end up unused by the caller (e.g. in the count_nnzs kernels), - # but LLVM appears smart enough to filter them away. - ptrs = ntuple(Val(N)) do i - ptr, col = @inbounds state[i] - col == cur_col ? ptr : zero(Ti) - end - - # advance the state - new_state = ntuple(Val(N)) do i - ptr, col = @inbounds state[i] - if col == cur_col - ptr += one(Ti) - col = get_col(i, ptr) - end - (; ptr, col) - end - - return (cur_col, ptrs), new_state -end - -struct CSCIterator{Ti,N,ATs} - col::Ti - row_ends::NTuple{N, Ti} - args::ATs -end - -function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N} - # check that `col` is valid for all arguments - @boundscheck begin - ntuple(Val(N)) do i - arg = @inbounds args[i] - arg isa CuSparseDeviceMatrixCSR && checkbounds(arg, 1, col) - end - end - - row_ends = ntuple(Val(N)) do i - arg = @inbounds args[i] - x = if arg isa CuSparseDeviceMatrixCSC - @inbounds(arg.colPtr[col+1i32]) - else - zero(Ti) - end - x - end - - CSCIterator{Ti, N, typeof(args)}(col, row_ends, args) -end - -@inline function Base.iterate(iter::CSCIterator{Ti,N}, state=nothing) where {Ti,N} - # helper function to get the column of a sparse array at a specific pointer - @inline function get_col(i, ptr) - arg = @inbounds iter.args[i] - if arg isa CuSparseDeviceMatrixCSC - col_end = @inbounds iter.row_ends[i] - if ptr < col_end - return @inbounds arg.rowVal[ptr] % Ti - end - end - typemax(Ti) - end - - # initialize the state - # - ptr: the current index into the rowVal/nzVal arrays - # - row: the current row index (cached so that we don't have to re-read each time) - state = something(state, - ntuple(Val(N)) do i - arg = @inbounds iter.args[i] - if arg isa CuSparseDeviceMatrixCSC - ptr = @inbounds iter.args[i].colPtr[iter.col] % Ti - row = @inbounds get_col(i, ptr) - else - ptr = typemax(Ti) - row = typemax(Ti) - end - (; ptr, row) - end - ) - - # determine the row we're currently processing - rows = ntuple(i -> @inbounds(state[i].row), Val(N)) - cur_row = min(rows...) - cur_row == typemax(Ti) && return - - # fetch the pointers (we don't look up the values, as the caller might want to index - # the sparse array directly, e.g., to mutate it). we don't return `ptrs` from the state - # directly, but first convert the `typemax(Ti)` to a more convenient zero value. - # NOTE: these values may end up unused by the caller (e.g. in the count_nnzs kernels), - # but LLVM appears smart enough to filter them away. - ptrs = ntuple(Val(N)) do i - ptr, row = @inbounds state[i] - row == cur_row ? ptr : zero(Ti) - end - - # advance the state - new_state = ntuple(Val(N)) do i - ptr, row = @inbounds state[i] - if row == cur_row - ptr += one(Ti) - row = get_col(i, ptr) - end - (; ptr, row) - end - - return (cur_row, ptrs), new_state -end - -# helpers to index a sparse or dense array -@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv}, - CuSparseDeviceMatrixCSC{Tv}, - CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv} - if ptr == 0 - return zero(Tv) - else - return @inbounds arg.nzVal[ptr]::Tv - end -end - -@inline function _getindex(arg::CuDeviceArray{Tv}, I, ptr)::Tv where {Tv} - return @inbounds arg[I]::Tv -end -@inline _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I) - -## sparse broadcast implementation - -iter_type(::Type{<:CuSparseMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti} -iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} -iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti} -iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} - -_has_row(A, offsets, row::Int32, fpreszeros::Bool) = fpreszeros ? 0i32 : row -_has_row(A::CuDeviceArray, offsets, row::Int32, ::Bool) = row -function _has_row(A::CuSparseDeviceVector, offsets, row::Int32, ::Bool)::Int32 - for row_ix in 1i32:length(A.iPtr) - arg_row = @inbounds A.iPtr[row_ix] - arg_row == row && return row_ix - arg_row > row && break - end - return 0i32 -end - -function _get_my_row(first_row)::Int32 - row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - return row_ix + first_row - 1i32 -end - -function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, - fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, - args...) where {Ti, N} - row = _get_my_row(first_row) - row > last_row && return - - # TODO load arg.iPtr slices into shared memory - row_is_nnz = 0i32 - arg_row_is_nnz = ntuple(Val(N)) do i - arg = @inbounds args[i] - _has_row(arg, offsets, row, fpreszeros)::Int32 - end - row_is_nnz = 0i32 - for i in 1:N - row_is_nnz |= @inbounds arg_row_is_nnz[i] - end - key = (row_is_nnz == 0i32) ? typemax(Ti) : row - @inbounds offsets[row - first_row + 1i32] = key => arg_row_is_nnz - return -end - -# kernel to count the number of non-zeros in a row, to determine the row offsets -function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, - offsets::AbstractVector{Ti}, - args...) where Ti - # every thread processes an entire row - leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - leading_dim > length(offsets)-1 && return - iter = @inbounds iter_type(T, Ti)(leading_dim, args...) - - # count the nonzero leading_dims of all inputs - accum = zero(Ti) - for (leading_dim, vals) in iter - accum += one(Ti) - end - - # the way we write the nnz counts is a bit strange, but done so that the result - # after accumulation can be directly used as the rowPtr/colPtr array of a CSR/CSC matrix. - @inbounds begin - if leading_dim == 1 - offsets[1] = 1 - end - offsets[leading_dim+1] = accum - end - - return -end - -function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti}, - offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, - args...) where {Tv, Ti, N, F} - row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - row_ix > output.nnz && return - row_and_ptrs = @inbounds offsets[row_ix] - row = @inbounds row_and_ptrs[1] - arg_ptrs = @inbounds row_and_ptrs[2] - vals = ntuple(Val(N)) do i - @inline - arg = @inbounds args[i] - # ptr is 0 if the sparse vector doesn't have an element at this row - # ptr is 0 if the arg is a scalar AND f preserves zeros - ptr = @inbounds arg_ptrs[i] - _getindex(arg, row, ptr) - end - output_val = f(vals...) - @inbounds output.iPtr[row_ix] = row - @inbounds output.nzVal[row_ix] = output_val - return -end - -function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, - args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti}, - CuSparseDeviceMatrixCSC{<:Any,Ti}}} - # every thread processes an entire row - leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2) - leading_dim > leading_dim_size && return - iter = @inbounds iter_type(T, Ti)(leading_dim, args...) - - - output_ptrs = output isa CuSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr - output_ivals = output isa CuSparseDeviceMatrixCSR ? output.colVal : output.rowVal - # fetch the row offset, and write it to the output - @inbounds begin - output_ptr = output_ptrs[leading_dim] = offsets[leading_dim] - if leading_dim == leading_dim_size - output_ptrs[leading_dim+1i32] = offsets[leading_dim+1i32] - end - end - - # set the values for this row - for (sub_leading_dim, ptrs) in iter - index_first = output isa CuSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim - index_second = output isa CuSparseDeviceMatrixCSR ? sub_leading_dim : leading_dim - I = CartesianIndex(index_first, index_second) - vals = ntuple(Val(length(args))) do i - arg = @inbounds args[i] - ptr = @inbounds ptrs[i] - _getindex(arg, I, ptr) - end - - @inbounds output_ivals[output_ptr] = sub_leading_dim - @inbounds output.nzVal[output_ptr] = f(vals...) - output_ptr += one(Ti) - end - - return -end -function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, - CuSparseMatrixCSC{Tv, Ti}}}, f, - output::CuDeviceArray, args...) where {Tv, Ti} - # every thread processes an entire row - leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - leading_dim_size = T <: CuSparseMatrixCSR ? size(output, 1) : size(output, 2) - leading_dim > leading_dim_size && return - iter = @inbounds iter_type(T, Ti)(leading_dim, args...) - - # set the values for this row - for (sub_leading_dim, ptrs) in iter - index_first = T <: CuSparseMatrixCSR ? leading_dim : sub_leading_dim - index_second = T <: CuSparseMatrixCSR ? sub_leading_dim : leading_dim - I = CartesianIndex(index_first, index_second) - vals = ntuple(Val(length(args))) do i - arg = @inbounds args[i] - ptr = @inbounds ptrs[i] - _getindex(arg, I, ptr) - end - - @inbounds output[I] = f(vals...) - end - - return -end - -function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F, - output::CuDeviceArray{Tv}, - offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, - args...) where {Tv, F, N, Ti} - # every thread processes an entire row - row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - row_ix > length(output) && return - row_and_ptrs = @inbounds offsets[row_ix] - row = @inbounds row_and_ptrs[1] - arg_ptrs = @inbounds row_and_ptrs[2] - vals = ntuple(Val(length(args))) do i - @inline - arg = @inbounds args[i] - # ptr is 0 if the sparse vector doesn't have an element at this row - # ptr is row if the arg is dense OR a scalar with non-zero-preserving f - # ptr is 0 if the arg is a scalar AND f preserves zeros - ptr = @inbounds arg_ptrs[i] - _getindex(arg, row, ptr) - end - @inbounds output[row] = f(vals...) - return -end -## COV_EXCL_STOP - -function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}}) - # find the sparse inputs - bc = Broadcast.flatten(bc) - sparse_args = findall(bc.args) do arg - arg isa AbstractCuSparseArray - end - sparse_types = unique(map(i->nameof(typeof(bc.args[i])), sparse_args)) - if length(sparse_types) > 1 - error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported") - end - sparse_typ = typeof(bc.args[first(sparse_args)]) - sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC,CuSparseVector} || - error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices") - Ti = if sparse_typ <: CuSparseMatrixCSR - reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args)) - elseif sparse_typ <: CuSparseMatrixCSC - reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args)) - elseif sparse_typ <: CuSparseVector - reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args)) - end - - # determine the output type - Tv = Broadcast.combine_eltypes(bc.f, eltype.(bc.args)) - if !Base.isconcretetype(Tv) - error("""GPU sparse broadcast resulted in non-concrete element type $Tv. - This probably means that the function you are broadcasting contains an error or type instability.""") - end - - # partially-evaluate the function, removing scalars. - parevalf, passedsrcargstup = capturescalars(bc.f, bc.args) - # check if the partially-evaluated function preserves zeros. if so, we'll only need to - # apply it to the sparse input arguments, preserving the sparse structure. - if all(arg->isa(arg, AbstractSparseArray), passedsrcargstup) - fofzeros = parevalf(_zeros_eltypes(passedsrcargstup...)...) - fpreszeros = _iszero(fofzeros) - else - fpreszeros = false - end - - # the kernels below parallelize across rows or cols, not elements, so it's unlikely - # we'll launch many threads. to maximize utilization, parallelize across blocks first. - rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1) # `size(bc, ::Int)` is missing - function compute_launch_config(kernel) - config = launch_configuration(kernel.fun) - if sparse_typ <: CuSparseMatrixCSR - threads = min(rows, config.threads) - blocks = max(cld(rows, threads), config.blocks) - threads = cld(rows, blocks) - elseif sparse_typ <: CuSparseMatrixCSC - threads = min(cols, config.threads) - blocks = max(cld(cols, threads), config.blocks) - threads = cld(cols, blocks) - elseif sparse_typ <: CuSparseVector - threads = 512 - blocks = max(cld(rows, threads), config.blocks) - end - (; threads, blocks) - end - # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20 - # but the only rows present in any sparse vector input are between 2 and 128, no need to - # launch massive threads. - # TODO: use the difference here to set the thread count - overall_first_row = one(Ti) - overall_last_row = Ti(rows) - offsets = nothing - # allocate the output container - if !fpreszeros && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC} - # either we have dense inputs, or the function isn't preserving zeros, - # so use a dense output to broadcast into. - output = CuArray{Tv}(undef, size(bc)) - - # since we'll be iterating the sparse inputs, we need to pre-fill the dense output - # with appropriate values (while setting the sparse inputs to zero). we do this by - # re-using the dense broadcast implementation. - nonsparse_args = map(bc.args) do arg - # NOTE: this assumes the broadcst is flattened, but not yet preprocessed - if arg isa AbstractCuSparseArray - zero(eltype(arg)) - else - arg - end - end - broadcast!(bc.f, output, nonsparse_args...) - elseif length(sparse_args) == 1 && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC} - # we only have a single sparse input, so we can reuse its structure for the output. - # this avoids a kernel launch and costly synchronization. - sparse_arg = bc.args[first(sparse_args)] - if sparse_typ <: CuSparseMatrixCSR - offsets = rowPtr = sparse_arg.rowPtr - colVal = similar(sparse_arg.colVal) - nzVal = similar(sparse_arg.nzVal, Tv) - output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc)) - elseif sparse_typ <: CuSparseMatrixCSC - offsets = colPtr = sparse_arg.colPtr - rowVal = similar(sparse_arg.rowVal) - nzVal = similar(sparse_arg.nzVal, Tv) - output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc)) - end - # NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays, - # while we do that in our kernel (for consistency with other code paths) - else - # determine the number of non-zero elements per row so that we can create an - # appropriately-structured output container - offsets = if sparse_typ <: CuSparseMatrixCSR - CuArray{Ti}(undef, rows+1) - elseif sparse_typ <: CuSparseMatrixCSC - CuArray{Ti}(undef, cols+1) - elseif sparse_typ <: CuSparseVector - CUDA.@allowscalar begin - arg_first_rows = ntuple(Val(length(bc.args))) do i - bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[1] - return one(Ti) - end - arg_last_rows = ntuple(Val(length(bc.args))) do i - bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[end] - return Ti(rows) - end - end - overall_first_row = min(arg_first_rows...) - overall_last_row = max(arg_last_rows...) - CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1) - end - let - args = if sparse_typ <: CuSparseVector - (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...) - else - (sparse_typ, offsets, bc.args...) - end - kernel = @cuda launch=false compute_offsets_kernel(args...) - threads, blocks = compute_launch_config(kernel) - kernel(args...; threads, blocks) - end - # accumulate these values so that we can use them directly as row pointer offsets, - # as well as to get the total nnz count to allocate the sparse output array. - # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort - if !(sparse_typ <: CuSparseVector) - accumulate!(Base.add_sum, offsets, offsets) - total_nnz = @allowscalar last(offsets[end]) - 1 - else - sort!(offsets; by=first) - total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets) - end - output = if sparse_typ <: CuSparseMatrixCSR - colVal = CuArray{Ti}(undef, total_nnz) - nzVal = CuArray{Tv}(undef, total_nnz) - CuSparseMatrixCSR(offsets, colVal, nzVal, size(bc)) - elseif sparse_typ <: CuSparseMatrixCSC - rowVal = CuArray{Ti}(undef, total_nnz) - nzVal = CuArray{Tv}(undef, total_nnz) - CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc)) - elseif sparse_typ <: CuSparseVector && !fpreszeros - CuArray{Tv}(undef, size(bc)) - elseif sparse_typ <: CuSparseVector && fpreszeros - iPtr = CUDA.zeros(Ti, total_nnz) - nzVal = CUDA.zeros(Tv, total_nnz) - CuSparseVector(iPtr, nzVal, rows) - end - if sparse_typ <: CuSparseVector && !fpreszeros - nonsparse_args = map(bc.args) do arg - # NOTE: this assumes the broadcst is flattened, but not yet preprocessed - if arg isa AbstractCuSparseArray - zero(eltype(arg)) - else - arg - end - end - broadcast!(bc.f, output, nonsparse_args...) - end - end - # perform the actual broadcast - if output isa AbstractCuSparseArray - args = (bc.f, output, offsets, bc.args...) - kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...) - else - args = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) : - (sparse_typ, bc.f, output, bc.args...) - kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...) - end - threads, blocks = compute_launch_config(kernel) - kernel(args...; threads, blocks) - - return output -end diff --git a/lib/cusparse/device.jl b/lib/cusparse/device.jl index 90a7b5e6f6..676c2cb638 100644 --- a/lib/cusparse/device.jl +++ b/lib/cusparse/device.jl @@ -3,44 +3,8 @@ # COV_EXCL_START using SparseArrays -# NOTE: this functionality is currently very bare-bones, only defining the array types -# without any device-compatible sparse array functionality - -# core types - -export CuSparseDeviceVector, CuSparseDeviceMatrixCSC, CuSparseDeviceMatrixCSR, - CuSparseDeviceMatrixBSR, CuSparseDeviceMatrixCOO - -struct CuSparseDeviceVector{Tv,Ti, A} <: AbstractSparseVector{Tv,Ti} - iPtr::CuDeviceVector{Ti, A} - nzVal::CuDeviceVector{Tv, A} - len::Int - nnz::Ti -end - -Base.length(g::CuSparseDeviceVector) = g.len -Base.size(g::CuSparseDeviceVector) = (g.len,) -SparseArrays.nnz(g::CuSparseDeviceVector) = g.nnz - -struct CuSparseDeviceMatrixCSC{Tv,Ti,A} <: SparseArrays.AbstractSparseMatrixCSC{Tv,Ti} - colPtr::CuDeviceVector{Ti, A} - rowVal::CuDeviceVector{Ti, A} - nzVal::CuDeviceVector{Tv, A} - dims::NTuple{2,Int} - nnz::Ti -end - -Base.length(g::CuSparseDeviceMatrixCSC) = prod(g.dims) -Base.size(g::CuSparseDeviceMatrixCSC) = g.dims -SparseArrays.nnz(g::CuSparseDeviceMatrixCSC) = g.nnz -SparseArrays.rowvals(g::CuSparseDeviceMatrixCSC) = g.rowVal -SparseArrays.getcolptr(g::CuSparseDeviceMatrixCSC) = g.colPtr -SparseArrays.getnzval(g::CuSparseDeviceMatrixCSC) = g.nzVal -SparseArrays.nzrange(g::CuSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1) -SparseArrays.nonzeros(g::CuSparseDeviceMatrixCSC) = g.nzVal - -const CuSparseDeviceColumnView{Tv, Ti} = SubArray{Tv, 1, <:CuSparseDeviceMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int}}, Int}} +const CuSparseDeviceColumnView{Tv, Ti} = SubArray{Tv, 1, <:GPUArrays.GPUSparseDeviceMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int}}, Int}} function SparseArrays.nonzeros(x::CuSparseDeviceColumnView) rowidx, colidx = parentindices(x) A = parent(x) @@ -62,106 +26,4 @@ function SparseArrays.nnz(x::CuSparseDeviceColumnView) return length(SparseArrays.nzrange(A, colidx)) end -struct CuSparseDeviceMatrixCSR{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti} - rowPtr::CuDeviceVector{Ti, A} - colVal::CuDeviceVector{Ti, A} - nzVal::CuDeviceVector{Tv, A} - dims::NTuple{2, Int} - nnz::Ti -end - -Base.length(g::CuSparseDeviceMatrixCSR) = prod(g.dims) -Base.size(g::CuSparseDeviceMatrixCSR) = g.dims -SparseArrays.nnz(g::CuSparseDeviceMatrixCSR) = g.nnz -SparseArrays.getnzval(g::CuSparseDeviceMatrixCSR) = g.nzVal - -struct CuSparseDeviceMatrixBSR{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti} - rowPtr::CuDeviceVector{Ti, A} - colVal::CuDeviceVector{Ti, A} - nzVal::CuDeviceVector{Tv, A} - dims::NTuple{2,Int} - blockDim::Ti - dir::Char - nnz::Ti -end - -Base.length(g::CuSparseDeviceMatrixBSR) = prod(g.dims) -Base.size(g::CuSparseDeviceMatrixBSR) = g.dims -SparseArrays.nnz(g::CuSparseDeviceMatrixBSR) = g.nnz -SparseArrays.getnzval(g::CuSparseDeviceMatrixBSR) = g.nzVal - -struct CuSparseDeviceMatrixCOO{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti} - rowInd::CuDeviceVector{Ti, A} - colInd::CuDeviceVector{Ti, A} - nzVal::CuDeviceVector{Tv, A} - dims::NTuple{2,Int} - nnz::Ti -end - -Base.length(g::CuSparseDeviceMatrixCOO) = prod(g.dims) -Base.size(g::CuSparseDeviceMatrixCOO) = g.dims -SparseArrays.nnz(g::CuSparseDeviceMatrixCOO) = g.nnz -SparseArrays.getnzval(g::CuSparseDeviceMatrixCOO) = g.nzVal - -struct CuSparseDeviceArrayCSR{Tv, Ti, N, M, A} <: AbstractSparseArray{Tv, Ti, N} - rowPtr::CuDeviceArray{Ti, M, A} - colVal::CuDeviceArray{Ti, M, A} - nzVal::CuDeviceArray{Tv, M, A} - dims::NTuple{N, Int} - nnz::Ti -end - -function CuSparseDeviceArrayCSR{Tv, Ti, N, A}(rowPtr::CuArray{<:Integer, M}, colVal::CuArray{<:Integer, M}, nzVal::CuArray{Tv, M}, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, N, A} - @assert M == N - 1 "CuSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1" - CuSparseDeviceArrayCSR{Tv, Ti, N, M, A}(rowPtr, colVal, nzVal, dims, length(nzVal)) -end - -Base.length(g::CuSparseDeviceArrayCSR) = prod(g.dims) -Base.size(g::CuSparseDeviceArrayCSR) = g.dims -SparseArrays.nnz(g::CuSparseDeviceArrayCSR) = g.nnz -SparseArrays.getnzval(g::CuSparseDeviceArrayCSR) = g.nzVal - -# input/output - -function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceVector) - println(io, "$(length(A))-element device sparse vector at:") - println(io, " iPtr: $(A.iPtr)") - print(io, " nzVal: $(A.nzVal)") -end - -function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCSR) - println(io, "$(length(A))-element device sparse matrix CSR at:") - println(io, " rowPtr: $(A.rowPtr)") - println(io, " colVal: $(A.colVal)") - print(io, " nzVal: $(A.nzVal)") -end - -function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCSC) - println(io, "$(length(A))-element device sparse matrix CSC at:") - println(io, " colPtr: $(A.colPtr)") - println(io, " rowVal: $(A.rowVal)") - print(io, " nzVal: $(A.nzVal)") -end - -function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixBSR) - println(io, "$(length(A))-element device sparse matrix BSR at:") - println(io, " rowPtr: $(A.rowPtr)") - println(io, " colVal: $(A.colVal)") - print(io, " nzVal: $(A.nzVal)") -end - -function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCOO) - println(io, "$(length(A))-element device sparse matrix COO at:") - println(io, " rowPtr: $(A.rowPtr)") - println(io, " colInd: $(A.colInd)") - print(io, " nzVal: $(A.nzVal)") -end - -function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceArrayCSR) - println(io, "$(length(A))-element device sparse array CSR at:") - println(io, " rowPtr: $(A.rowPtr)") - println(io, " colVal: $(A.colVal)") - print(io, " nzVal: $(A.nzVal)") -end - # COV_EXCL_STOP diff --git a/lib/cusparse/reduce.jl b/lib/cusparse/reduce.jl deleted file mode 100644 index 25d86a29ba..0000000000 --- a/lib/cusparse/reduce.jl +++ /dev/null @@ -1,122 +0,0 @@ -# TODO: implement mapreducedim! - -function Base.mapreduce(f, op, A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}; - dims=:, init=nothing) - # figure out the destination container type by looking at the initializer element, - # or by relying on inference to reason through the map and reduce functions - if init === nothing - ET = Broadcast.combine_eltypes(f, (A,)) - ET = Base.promote_op(op, ET, ET) - (ET === Union{} || ET === Any) && - error("mapreduce cannot figure the output element type, please pass an explicit init value") - - init = zero(ET) - else - ET = typeof(init) - end - - f_preserves_zeros = ( f(zero(ET)) == zero(ET) ) - # we only handle reducing along one of the two dimensions, - # or a complete reduction (requiring an additional pass) - in(dims, [Colon(), 1, 2]) || error("only dims=:, dims=1 or dims=2 is supported") - - if A isa CuSparseMatrixCSR && dims == 1 - A = CuSparseMatrixCSC(A) - elseif A isa CuSparseMatrixCSC && dims == 2 - A = CuSparseMatrixCSR(A) - end - - m, n = size(A) - if A isa CuSparseMatrixCSR - output = CuArray{ET}(undef, m) - - kernel = @cuda launch=false csr_reduce_kernel(f, op, init, f_preserves_zeros, output, A) - config = launch_configuration(kernel.fun) - threads = min(m, config.threads) - blocks = cld(m, threads) - elseif A isa CuSparseMatrixCSC - output = CuArray{ET}(undef, (1, n)) - - kernel = @cuda launch=false csc_reduce_kernel(f, op, init, f_preserves_zeros, output, A) - config = launch_configuration(kernel.fun) - threads = min(n, config.threads) - blocks = cld(n, threads) - end - kernel(f, op, init, f_preserves_zeros, output, A; threads, blocks) - - if dims == Colon() - mapreduce(identity, op, output; init) - else - output - end -end - -## COV_EXCL_START -function csr_reduce_kernel(f::F, op::OP, neutral, zeros_preserved::Bool, output::CuDeviceArray, args...) where {F, OP} - # every thread processes an entire row - row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - row > size(output, 1) && return - iter = @inbounds CSRIterator{Int}(row, args...) - - val = op(neutral, neutral) - - # reduce the values for this row - for (col, ptrs) in iter - I = CartesianIndex(row, col) - vals = ntuple(Val(length(args))) do i - arg = @inbounds args[i] - ptr = @inbounds ptrs[i] - _getindex(arg, I, ptr) - end - val = op(val, f(vals...)) - end - if !zeros_preserved - f_zero_val = f(zero(neutral)) - next_row_ind = row+1i32 - nzs_this_row = ntuple(Val(length(args))) do i - max_n_zeros = size(args[i], 2) - arg_row_ptr = args[i].rowPtr - nz_this_row = max_n_zeros - (@inbounds(arg_row_ptr[next_row_ind]) - @inbounds(arg_row_ptr[row])) - return nz_this_row * f_zero_val - end - val = op(val, nzs_this_row...) - end - - @inbounds output[row] = val - return -end - -function csc_reduce_kernel(f::F, op::OP, neutral, zeros_preserved::Bool, output::CuDeviceArray, args...) where {F, OP} - # every thread processes an entire column - col = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - col > size(output, 2) && return - iter = @inbounds CSCIterator{Int}(col, args...) - - val = op(neutral, neutral) - - # reduce the values for this col - for (row, ptrs) in iter - I = CartesianIndex(row, col) - vals = ntuple(Val(length(args))) do i - arg = @inbounds args[i] - ptr = @inbounds ptrs[i] - _getindex(arg, I, ptr) - end - val = op(val, f(vals...)) - end - if !zeros_preserved - f_zero_val = f(zero(neutral)) - next_col_ind = col+1i32 - nzs_this_col = ntuple(Val(length(args))) do i - max_n_zeros = size(args[i], 1) - arg_col_ptr = args[i].colPtr - nz_this_col = max_n_zeros - (@inbounds(arg_col_ptr[next_col_ind]) - @inbounds(arg_col_ptr[col])) - return nz_this_col * f_zero_val - end - val = op(val, nzs_this_col...) - end - - @inbounds output[col] = val - return -end -## COV_EXCL_STOP diff --git a/src/CUDAKernels.jl b/src/CUDAKernels.jl index 5a36ed5eaa..5b4fd54fa3 100644 --- a/src/CUDAKernels.jl +++ b/src/CUDAKernels.jl @@ -1,7 +1,7 @@ module CUDAKernels using ..CUDA -using ..CUDA: @device_override, CUSPARSE, default_memory, UnifiedMemory +using ..CUDA: @device_override, CUSPARSE, default_memory, UnifiedMemory, GPUArrays import KernelAbstractions as KA @@ -26,7 +26,9 @@ CUDABackend(; prefer_blocks=false, always_inline=false) = CUDABackend(prefer_blo @inline KA.ones(::CUDABackend, ::Type{T}, dims::Tuple; unified::Bool = false) where T = fill!(CuArray{T, length(dims), unified ? UnifiedMemory : default_memory}(undef, dims), one(T)) KA.get_backend(::CuArray) = CUDABackend() -KA.get_backend(::CUSPARSE.AbstractCuSparseArray) = CUDABackend() +KA.get_backend(::CUSPARSE.CuSparseVector) = CUDABackend() +KA.get_backend(::CUSPARSE.CuSparseMatrixCSC) = CUDABackend() +KA.get_backend(::CUSPARSE.CuSparseMatrixCSR) = CUDABackend() KA.synchronize(::CUDABackend) = synchronize() KA.functional(::CUDABackend) = CUDA.functional() @@ -34,8 +36,8 @@ KA.functional(::CUDABackend) = CUDA.functional() KA.supports_unified(::CUDABackend) = true Adapt.adapt_storage(::CUDABackend, a::AbstractArray) = Adapt.adapt(CuArray, a) -Adapt.adapt_storage(::CUDABackend, a::Union{CuArray,CUSPARSE.AbstractCuSparseArray}) = a -Adapt.adapt_storage(::KA.CPU, a::Union{CuArray,CUSPARSE.AbstractCuSparseArray}) = Adapt.adapt(Array, a) +Adapt.adapt_storage(::CUDABackend, a::Union{CuArray,GPUArrays.AbstractGPUSparseArray}) = a +Adapt.adapt_storage(::KA.CPU, a::Union{CuArray,GPUArrays.AbstractGPUSparseArray}) = Adapt.adapt(Array, a) ## memory operations diff --git a/test/Project.toml b/test/Project.toml index 861e9cc219..05b169e72e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc" CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -12,8 +13,8 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f" From c3559046bb6ab706edef55982cd9eb643bc8783d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 18 Oct 2025 10:08:17 -0400 Subject: [PATCH 2/5] Restore imports --- lib/cusparse/array.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl index 94d32f084f..491ca4c2d6 100644 --- a/lib/cusparse/array.jl +++ b/lib/cusparse/array.jl @@ -9,6 +9,7 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCO using LinearAlgebra: BlasFloat using SparseArrays +using SparseArrays: nonzeroinds, nonzeros, rowvals, getcolptr abstract type AbstractCuSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, 1} end abstract type AbstractCuSparseMatrix{Tv, Ti} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, 2} end From c473fa7250524a5cbb239b37da402d59d646c1ed Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 21 Oct 2025 07:48:07 -0400 Subject: [PATCH 3/5] More work CUSPARSE tests --- lib/cusparse/CUSPARSE.jl | 3 - lib/cusparse/array.jl | 70 +++++++------- lib/cusparse/conversions.jl | 39 ++++---- lib/cusparse/device.jl | 29 ------ lib/cusparse/interfaces.jl | 14 +-- lib/cusparse/linalg.jl | 128 ------------------------- test/libraries/cusparse/conversions.jl | 12 +-- test/libraries/cusparse/device.jl | 35 +++---- 8 files changed, 92 insertions(+), 238 deletions(-) delete mode 100644 lib/cusparse/device.jl diff --git a/lib/cusparse/CUSPARSE.jl b/lib/cusparse/CUSPARSE.jl index ae1bdf5fb8..c770072d1a 100644 --- a/lib/cusparse/CUSPARSE.jl +++ b/lib/cusparse/CUSPARSE.jl @@ -47,9 +47,6 @@ include("generic.jl") # high-level integrations include("interfaces.jl") -# native functionality -include("device.jl") - include("batched.jl") diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl index 491ca4c2d6..631e8ef8a0 100644 --- a/lib/cusparse/array.jl +++ b/lib/cusparse/array.jl @@ -9,7 +9,7 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCO using LinearAlgebra: BlasFloat using SparseArrays -using SparseArrays: nonzeroinds, nonzeros, rowvals, getcolptr +using SparseArrays: nonzeroinds, nonzeros, rowvals, getcolptr, dimlub abstract type AbstractCuSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, 1} end abstract type AbstractCuSparseMatrix{Tv, Ti} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, 2} end @@ -43,6 +43,7 @@ mutable struct CuSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal)) end end +CuSparseMatrixCSC{Tv, Ti}(csc::CuSparseMatrixCSC{Tv, Ti}) where {Tv, Ti} = csc SparseArrays.rowvals(g::T) where {T<:CuSparseVector} = nonzeroinds(g) @@ -84,6 +85,7 @@ mutable struct CuSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR end end +CuSparseMatrixCSR{Tv, Ti}(csr::CuSparseMatrixCSR{Tv, Ti}) where {Tv, Ti} = csr CuSparseMatrixCSR(A::CuSparseMatrixCSR) = A function CUDA.unsafe_free!(xs::CuSparseMatrixCSR) @@ -108,6 +110,10 @@ GPUArrays._dense_array_type(::Type{<:CuSparseMatrixCSR}) = CuArray GPUArrays._csc_type(sa::CuSparseMatrixCSR) = CuSparseMatrixCSC GPUArrays._csr_type(sa::CuSparseMatrixCSC) = CuSparseMatrixCSR +GPUArrays._coo_type(sa::Union{CuSparseMatrixCSR, Transpose{<:Any,<:CuSparseMatrixCSR}, Adjoint{<:Any,<:CuSparseMatrixCSR}}) = CuSparseMatrixCOO +GPUArrays._coo_type(sa::Union{CuSparseMatrixCSC, Transpose{<:Any,<:CuSparseMatrixCSC}, Adjoint{<:Any,<:CuSparseMatrixCSC}}) = CuSparseMatrixCOO +GPUArrays._coo_type(::Type{T}) where {T<:Union{CuSparseMatrixCSR, Transpose{<:Any,<:CuSparseMatrixCSR}, Adjoint{<:Any,<:CuSparseMatrixCSR}}} = CuSparseMatrixCOO +GPUArrays._coo_type(::Type{T}) where {T<:Union{CuSparseMatrixCSC, Transpose{<:Any,<:CuSparseMatrixCSC}, Adjoint{<:Any,<:CuSparseMatrixCSC}}} = CuSparseMatrixCOO """ Container to hold sparse matrices in block compressed sparse row (BSR) format on @@ -326,21 +332,6 @@ end ## sparse array interface -function SparseArrays.findnz(S::MT) where {MT <: AbstractCuSparseMatrix} - S2 = CuSparseMatrixCOO(S) - I = S2.rowInd - J = S2.colInd - V = S2.nzVal - - # To make it compatible with the SparseArrays.jl version - idxs = sortperm(J) - I = I[idxs] - J = J[idxs] - V = V[idxs] - - return (I, J, V) -end - function SparseArrays.sparsevec(I::CuArray{Ti}, V::CuArray{Tv}, n::Integer) where {Ti,Tv} CuSparseVector(I, V, n) end @@ -396,16 +387,6 @@ function _cuda_spdiagm_internal(kv::Pair{T,<:CuVector}...) where {T<:Integer} return I, J, V, m, n end -LinearAlgebra.issymmetric(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - transpose(M), Inf) == 0 : false -LinearAlgebra.ishermitian(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - adjoint(M), Inf) == 0 : false - -LinearAlgebra.istriu(M::UpperTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractCuSparseMatrix, Adjoint{<:Any, <:AbstractCuSparseMatrix}, Transpose{<:Any, <:AbstractCuSparseMatrix}}} = true -LinearAlgebra.istril(M::UpperTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractCuSparseMatrix, Adjoint{<:Any, <:AbstractCuSparseMatrix}, Transpose{<:Any, <:AbstractCuSparseMatrix}}} = false -LinearAlgebra.istriu(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractCuSparseMatrix, Adjoint{<:Any, <:AbstractCuSparseMatrix}, Transpose{<:Any, <:AbstractCuSparseMatrix}}} = false -LinearAlgebra.istril(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractCuSparseMatrix, Adjoint{<:Any, <:AbstractCuSparseMatrix}, Transpose{<:Any, <:AbstractCuSparseMatrix}}} = true - -Hermitian{T}(Mat::CuSparseMatrix{T}) where {T} = Hermitian{eltype(Mat),typeof(Mat)}(Mat,'U') - SparseArrays.nnz(g::CuSparseMatrixBSR) = g.nnzb * g.blockDim * g.blockDim @@ -550,6 +531,7 @@ CuSparseMatrixCSR{T}(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {T, Tv} = CuVector{T}(conj.(parent(Mat).nzval)), size(Mat)) CuSparseMatrixCSC{T}(Mat::Union{Transpose{Tv, <:SparseMatrixCSC}, Adjoint{Tv, <:SparseMatrixCSC}}) where {T, Tv} = CuSparseMatrixCSC(CuSparseMatrixCSR{T}(Mat)) CuSparseMatrixCSR{T}(Mat::SparseMatrixCSC) where {T} = CuSparseMatrixCSR(CuSparseMatrixCSC{T}(Mat)) +CuSparseMatrixCSR{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} = CuSparseMatrixCSR(CuSparseMatrixCSC{Tv}(Mat)) CuSparseMatrixBSR{T}(Mat::SparseMatrixCSC, blockdim) where {T} = CuSparseMatrixBSR(CuSparseMatrixCSR{T}(Mat), blockdim) CuSparseMatrixCOO{T}(Mat::SparseMatrixCSC) where {T} = CuSparseMatrixCOO(CuSparseMatrixCSR{T}(Mat)) @@ -571,12 +553,12 @@ CuSparseMatrixCSC(x::Adjoint{T}) where {T} = CuSparseMatrixCSC{T}(x) CuSparseMatrixCOO(x::Transpose{T}) where {T} = CuSparseMatrixCOO{T}(x) CuSparseMatrixCOO(x::Adjoint{T}) where {T} = CuSparseMatrixCOO{T}(x) -CuSparseMatrixCSR(x::Transpose{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCSR(_sptranspose(parent(x))) -CuSparseMatrixCSC(x::Transpose{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCSC(_sptranspose(parent(x))) -CuSparseMatrixCOO(x::Transpose{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCOO(_sptranspose(parent(x))) -CuSparseMatrixCSR(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCSR(_spadjoint(parent(x))) -CuSparseMatrixCSC(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCSC(_spadjoint(parent(x))) -CuSparseMatrixCOO(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCOO(_spadjoint(parent(x))) +CuSparseMatrixCSR(x::Transpose{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCSR(GPUArrays._sptranspose(parent(x))) +CuSparseMatrixCSC(x::Transpose{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCSC(GPUArrays._sptranspose(parent(x))) +CuSparseMatrixCOO(x::Transpose{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCOO(GPUArrays._sptranspose(parent(x))) +CuSparseMatrixCSR(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCSR(GPUArrays._spadjoint(parent(x))) +CuSparseMatrixCSC(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCSC(GPUArrays._spadjoint(parent(x))) +CuSparseMatrixCOO(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}) where {T} = CuSparseMatrixCOO(GPUArrays._spadjoint(parent(x))) # gpu to cpu SparseArrays.SparseVector(x::CuSparseVector) = SparseVector(length(x), Array(SparseArrays.nonzeroinds(x)), Array(SparseArrays.nonzeros(x))) @@ -592,6 +574,12 @@ Base.collect(x::CuSparseMatrixCSR) = collect(SparseMatrixCSC(x)) Base.collect(x::CuSparseMatrixBSR) = collect(SparseMatrixCSC(x)) Base.collect(x::CuSparseMatrixCOO) = collect(SparseMatrixCSC(x)) +Base.Array(x::CuSparseVector) = collect(SparseVector(x)) +Base.Array(x::CuSparseMatrixCSC) = collect(SparseMatrixCSC(x)) +Base.Array(x::CuSparseMatrixCSR) = collect(SparseMatrixCSC(x)) +Base.Array(x::CuSparseMatrixBSR) = collect(SparseMatrixCSC(x)) +Base.Array(x::CuSparseMatrixCOO) = collect(SparseMatrixCSC(x)) + Adapt.adapt_storage(::Type{CuArray}, xs::SparseVector) = CuSparseVector(xs) Adapt.adapt_storage(::Type{CuArray}, xs::SparseMatrixCSC) = CuSparseMatrixCSC(xs) Adapt.adapt_storage(::Type{CuArray{T}}, xs::SparseVector) where {T} = CuSparseVector{T}(xs) @@ -787,6 +775,16 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSC) ) end +function GPUArrays.GPUSparseDeviceMatrixBSR(rowPtr::CuDeviceVector{Ti, A}, + colVal::CuDeviceVector{Ti, A}, + nzVal::CuDeviceVector{Tv, A}, + dims::NTuple{2, Int}, + blockDim::Ti, + dir::Char, + nnz::Ti) where {Ti, Tv, A} + GPUArrays.GPUSparseDeviceMatrixBSR{Tv, Ti, CuDeviceVector{Ti, A}, CuDeviceVector{Tv, A}, A}(rowPtr, colVal, nzVal, dims, blockDim, dir, nnz) +end + function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixBSR) return GPUArrays.GPUSparseDeviceMatrixBSR( adapt(to, x.rowPtr), @@ -797,6 +795,14 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixBSR) ) end +function GPUArrays.GPUSparseDeviceMatrixCOO(rowInd::CuDeviceVector{Ti, A}, + colInd::CuDeviceVector{Ti, A}, + nzVal::CuDeviceVector{Tv, A}, + dims::NTuple{2, Int}, + nnz::Ti) where {Ti, Tv, A} + GPUArrays.GPUSparseDeviceMatrixCOO{Tv, Ti, CuDeviceVector{Ti, A}, CuDeviceVector{Tv, A}, A}(rowInd, colInd, nzVal, dims, nnz) +end + function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCOO) return GPUArrays.GPUSparseDeviceMatrixCOO( adapt(to, x.rowInd), diff --git a/lib/cusparse/conversions.jl b/lib/cusparse/conversions.jl index e76ee6e801..afa751274d 100644 --- a/lib/cusparse/conversions.jl +++ b/lib/cusparse/conversions.jl @@ -1,8 +1,8 @@ export sort_csc, sort_csr, sort_coo adjtrans_wrappers = ((identity, identity), - (M -> :(Transpose{T, <:$M}), M -> :(_sptranspose(parent($M)))), - (M -> :(Adjoint{T, <:$M}), M -> :(_spadjoint(parent($M))))) + (M -> :(Transpose{T, <:$M}), M -> :(GPUArrays._sptranspose(parent($M)))), + (M -> :(Adjoint{T, <:$M}), M -> :(GPUArrays._spadjoint(parent($M))))) # conversion routines between different sparse and dense storage formats @@ -330,7 +330,7 @@ end # by flipping rows and columns, we can use that to get CSC to CSR too for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64) @eval begin - function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1) + function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1) m,n = size(csr) colPtr = CUDA.zeros(Cint, n+1) rowVal = CUDA.zeros(Cint, nnz(csr)) @@ -349,8 +349,9 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64) end CuSparseMatrixCSC(colPtr,rowVal,nzVal,size(csr)) end - - function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1) + CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1) = + CuSparseMatrixCSC{$elty, Cint}(csr; index=index, action=action, algo=algo) + function CuSparseMatrixCSR{$elty, Cint}(csc::CuSparseMatrixCSC{$elty, Cint}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1) m,n = size(csc) rowPtr = CUDA.zeros(Cint,m+1) colVal = CUDA.zeros(Cint,nnz(csc)) @@ -369,6 +370,8 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64) end CuSparseMatrixCSR(rowPtr,colVal,nzVal,size(csc)) end + CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty, Cint}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1) = + CuSparseMatrixCSR{$elty, Cint}(csc; index=index, action=action, algo=algo) end end @@ -585,42 +588,46 @@ end ## CSR to COO and vice-versa -function CuSparseMatrixCSR{Tv}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv} +function CuSparseMatrixCSR{Tv,Cint}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv} m,n = size(coo) - nnz(coo) == 0 && return CuSparseMatrixCSR{Tv}(CUDA.ones(Cint, m+1), coo.colInd, nonzeros(coo), size(coo)) + nnz(coo) == 0 && return CuSparseMatrixCSR{Tv,Cint}(CUDA.ones(Cint, m+1), coo.colInd, nonzeros(coo), size(coo)) coo = sort_coo(coo, 'R') csrRowPtr = CuVector{Cint}(undef, m+1) cusparseXcoo2csr(handle(), coo.rowInd, nnz(coo), m, csrRowPtr, index) - CuSparseMatrixCSR{Tv}(csrRowPtr, coo.colInd, nonzeros(coo), size(coo)) + CuSparseMatrixCSR{Tv,Cint}(csrRowPtr, coo.colInd, nonzeros(coo), size(coo)) end +CuSparseMatrixCSR{Tv}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv} = CuSparseMatrixCSR{Tv, Cint}(coo) -function CuSparseMatrixCOO{Tv}(csr::CuSparseMatrixCSR{Tv}; index::SparseChar='O') where {Tv} +function CuSparseMatrixCOO{Tv,Cint}(csr::CuSparseMatrixCSR{Tv}; index::SparseChar='O') where {Tv} m,n = size(csr) - nnz(csr) == 0 && return CuSparseMatrixCOO{Tv}(CUDA.zeros(Cint, 0), CUDA.zeros(Cint, 0), nonzeros(csr), size(csr)) + nnz(csr) == 0 && return CuSparseMatrixCOO{Tv,Cint}(CUDA.zeros(Cint, 0), CUDA.zeros(Cint, 0), nonzeros(csr), size(csr)) cooRowInd = CuVector{Cint}(undef, nnz(csr)) cusparseXcsr2coo(handle(), csr.rowPtr, nnz(csr), m, cooRowInd, index) - CuSparseMatrixCOO{Tv}(cooRowInd, csr.colVal, nonzeros(csr), size(csr)) + CuSparseMatrixCOO{Tv,Cint}(cooRowInd, csr.colVal, nonzeros(csr), size(csr)) end +CuSparseMatrixCOO{Tv}(csr::CuSparseMatrixCSR{Tv}; index::SparseChar='O') where {Tv} = CuSparseMatrixCOO{Tv, Cint}(csr) ### CSC to COO and viceversa -function CuSparseMatrixCSC{Tv}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv} +function CuSparseMatrixCSC{Tv,Cint}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv} m,n = size(coo) nnz(coo) == 0 && return CuSparseMatrixCSC{Tv}(CUDA.ones(Cint, n+1), coo.rowInd, nonzeros(coo), size(coo)) coo = sort_coo(coo, 'C') cscColPtr = CuVector{Cint}(undef, n+1) cusparseXcoo2csr(handle(), coo.colInd, nnz(coo), n, cscColPtr, index) - CuSparseMatrixCSC{Tv}(cscColPtr, coo.rowInd, nonzeros(coo), size(coo)) + CuSparseMatrixCSC{Tv,Cint}(cscColPtr, coo.rowInd, nonzeros(coo), size(coo)) end +CuSparseMatrixCSC{Tv}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv} = CuSparseMatrixCSC{Tv, Cint}(coo) -function CuSparseMatrixCOO{Tv}(csc::CuSparseMatrixCSC{Tv}; index::SparseChar='O') where {Tv} +function CuSparseMatrixCOO{Tv,Cint}(csc::CuSparseMatrixCSC{Tv}; index::SparseChar='O') where {Tv} m,n = size(csc) - nnz(csc) == 0 && return CuSparseMatrixCOO{Tv}(CUDA.zeros(Cint, 0), CUDA.zeros(Cint, 0), nonzeros(csc), size(csc)) + nnz(csc) == 0 && return CuSparseMatrixCOO{Tv,Cint}(CUDA.zeros(Cint, 0), CUDA.zeros(Cint, 0), nonzeros(csc), size(csc)) cooColInd = CuVector{Cint}(undef, nnz(csc)) cusparseXcsr2coo(handle(), csc.colPtr, nnz(csc), n, cooColInd, index) - coo = CuSparseMatrixCOO{Tv}(csc.rowVal, cooColInd, nonzeros(csc), size(csc)) + coo = CuSparseMatrixCOO{Tv,Cint}(csc.rowVal, cooColInd, nonzeros(csc), size(csc)) coo = sort_coo(coo, 'R') end +CuSparseMatrixCOO{Tv}(csc::CuSparseMatrixCSC{Tv}; index::SparseChar='O') where {Tv} = CuSparseMatrixCOO{Tv, Cint}(csc) ### BSR to COO and vice-versa diff --git a/lib/cusparse/device.jl b/lib/cusparse/device.jl deleted file mode 100644 index 676c2cb638..0000000000 --- a/lib/cusparse/device.jl +++ /dev/null @@ -1,29 +0,0 @@ -# on-device sparse array functionality -# should be excluded from coverage counts -# COV_EXCL_START -using SparseArrays - - -const CuSparseDeviceColumnView{Tv, Ti} = SubArray{Tv, 1, <:GPUArrays.GPUSparseDeviceMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int}}, Int}} -function SparseArrays.nonzeros(x::CuSparseDeviceColumnView) - rowidx, colidx = parentindices(x) - A = parent(x) - @inbounds y = view(SparseArrays.nonzeros(A), SparseArrays.nzrange(A, colidx)) - return y -end - -function SparseArrays.nonzeroinds(x::CuSparseDeviceColumnView) - rowidx, colidx = parentindices(x) - A = parent(x) - @inbounds y = view(SparseArrays.rowvals(A), SparseArrays.nzrange(A, colidx)) - return y -end -SparseArrays.rowvals(x::CuSparseDeviceColumnView) = SparseArrays.nonzeroinds(x) - -function SparseArrays.nnz(x::CuSparseDeviceColumnView) - rowidx, colidx = parentindices(x) - A = parent(x) - return length(SparseArrays.nzrange(A, colidx)) -end - -# COV_EXCL_STOP diff --git a/lib/cusparse/interfaces.jl b/lib/cusparse/interfaces.jl index f6659d7b55..d764410b3a 100644 --- a/lib/cusparse/interfaces.jl +++ b/lib/cusparse/interfaces.jl @@ -2,29 +2,29 @@ using LinearAlgebra using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, MulAddMul, AdjOrTrans -export _spadjoint, _sptranspose +using GPUArrays: _spadjoint, _sptranspose -function _spadjoint(A::CuSparseMatrixCSR) +function GPUArrays._spadjoint(A::CuSparseMatrixCSR) Aᴴ = CuSparseMatrixCSC(A.rowPtr, A.colVal, conj(A.nzVal), reverse(size(A))) CuSparseMatrixCSR(Aᴴ) end -function _sptranspose(A::CuSparseMatrixCSR) +function GPUArrays._sptranspose(A::CuSparseMatrixCSR) Aᵀ = CuSparseMatrixCSC(A.rowPtr, A.colVal, A.nzVal, reverse(size(A))) CuSparseMatrixCSR(Aᵀ) end -function _spadjoint(A::CuSparseMatrixCSC) +function GPUArrays._spadjoint(A::CuSparseMatrixCSC) Aᴴ = CuSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A))) CuSparseMatrixCSC(Aᴴ) end -function _sptranspose(A::CuSparseMatrixCSC) +function GPUArrays._sptranspose(A::CuSparseMatrixCSC) Aᵀ = CuSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A))) CuSparseMatrixCSC(Aᵀ) end -function _spadjoint(A::CuSparseMatrixCOO) +function GPUArrays._spadjoint(A::CuSparseMatrixCOO) # we use sparse instead of CuSparseMatrixCOO because we want to sort the matrix. sparse(A.colInd, A.rowInd, conj(A.nzVal), reverse(size(A))..., fmt = :coo) end -function _sptranspose(A::CuSparseMatrixCOO) +function GPUArrays._sptranspose(A::CuSparseMatrixCOO) # we use sparse instead of CuSparseMatrixCOO because we want to sort the matrix. sparse(A.colInd, A.rowInd, A.nzVal, reverse(size(A))..., fmt = :coo) end diff --git a/lib/cusparse/linalg.jl b/lib/cusparse/linalg.jl index 4be578e284..ebcde249be 100644 --- a/lib/cusparse/linalg.jl +++ b/lib/cusparse/linalg.jl @@ -1,30 +1,6 @@ using LinearAlgebra using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, BlasInt -function LinearAlgebra.opnorm(A::CuSparseMatrixCSR, p::Real=2) - if p == Inf - return maximum(sum(abs, A; dims=2)) - elseif p == 1 - return maximum(sum(abs, A; dims=1)) - else - throw(ArgumentError("p=$p is not supported")) - end -end - -LinearAlgebra.opnorm(A::CuSparseMatrixCSC, p::Real=2) = opnorm(CuSparseMatrixCSR(A), p) - -function LinearAlgebra.norm(A::AbstractCuSparseMatrix{T}, p::Real=2) where T - if p == Inf - return maximum(abs.(A.nzVal)) - elseif p == -Inf - return minimum(abs.(A.nzVal)) - elseif p == 0 - return Float64(A.nnz) - else - return sum(abs.(A.nzVal).^p)^(1/p) - end -end - function LinearAlgebra.triu(A::CuSparseMatrixCOO, k::Integer=0) mask = A.rowInd .+ k .<= A.colInd rows = A.rowInd[mask] @@ -242,107 +218,3 @@ function LinearAlgebra.dot(y::CuVector{T}, A::CuSparseMatrixCSR{T}, x::CuVector{ return sum(result) end -# work around upstream breakage from JuliaLang/julia#55547 -@static if VERSION >= v"1.11.2" - const CuSparseUpperOrUnitUpperTriangular = LinearAlgebra.UpperOrUnitUpperTriangular{ - <:Any,<:Union{<:AbstractCuSparseMatrix, Adjoint{<:Any, <:AbstractCuSparseMatrix}, Transpose{<:Any, <:AbstractCuSparseMatrix}}} - const CuSparseLowerOrUnitLowerTriangular = LinearAlgebra.LowerOrUnitLowerTriangular{ - <:Any,<:Union{<:AbstractCuSparseMatrix, Adjoint{<:Any, <:AbstractCuSparseMatrix}, Transpose{<:Any, <:AbstractCuSparseMatrix}}} - LinearAlgebra.istriu(::CuSparseUpperOrUnitUpperTriangular) = true - LinearAlgebra.istril(::CuSparseUpperOrUnitUpperTriangular) = false - LinearAlgebra.istriu(::CuSparseLowerOrUnitLowerTriangular) = false - LinearAlgebra.istril(::CuSparseLowerOrUnitLowerTriangular) = true -end - -for SparseMatrixType in [:CuSparseMatrixCSC, :CuSparseMatrixCSR] - @eval begin - LinearAlgebra.triu(A::$SparseMatrixType{T}, k::Integer) where {T} = - $SparseMatrixType( triu(CuSparseMatrixCOO(A), k) ) - LinearAlgebra.triu(A::Transpose{T,<:$SparseMatrixType}, k::Integer) where {T} = - $SparseMatrixType( triu(CuSparseMatrixCOO(_sptranspose(parent(A))), k) ) - LinearAlgebra.triu(A::Adjoint{T,<:$SparseMatrixType}, k::Integer) where {T} = - $SparseMatrixType( triu(CuSparseMatrixCOO(_spadjoint(parent(A))), k) ) - - LinearAlgebra.tril(A::$SparseMatrixType{T}, k::Integer) where {T} = - $SparseMatrixType( tril(CuSparseMatrixCOO(A), k) ) - LinearAlgebra.tril(A::Transpose{T,<:$SparseMatrixType}, k::Integer) where {T} = - $SparseMatrixType( tril(CuSparseMatrixCOO(_sptranspose(parent(A))), k) ) - LinearAlgebra.tril(A::Adjoint{T,<:$SparseMatrixType}, k::Integer) where {T} = - $SparseMatrixType( tril(CuSparseMatrixCOO(_spadjoint(parent(A))), k) ) - - LinearAlgebra.triu(A::Union{$SparseMatrixType{T}, Transpose{T,<:$SparseMatrixType}, Adjoint{T,<:$SparseMatrixType}}) where {T} = - $SparseMatrixType( triu(CuSparseMatrixCOO(A), 0) ) - LinearAlgebra.tril(A::Union{$SparseMatrixType{T},Transpose{T,<:$SparseMatrixType}, Adjoint{T,<:$SparseMatrixType}}) where {T} = - $SparseMatrixType( tril(CuSparseMatrixCOO(A), 0) ) - - LinearAlgebra.kron(A::$SparseMatrixType{T}, B::$SparseMatrixType{T}) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(A), CuSparseMatrixCOO(B)) ) - LinearAlgebra.kron(A::$SparseMatrixType{T}, B::Diagonal) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(A), B) ) - LinearAlgebra.kron(A::Diagonal, B::$SparseMatrixType{T}) where {T} = - $SparseMatrixType( kron(A, CuSparseMatrixCOO(B)) ) - - LinearAlgebra.kron(A::Transpose{T,<:$SparseMatrixType}, B::$SparseMatrixType{T}) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(_sptranspose(parent(A))), CuSparseMatrixCOO(B)) ) - LinearAlgebra.kron(A::$SparseMatrixType{T}, B::Transpose{T,<:$SparseMatrixType}) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(A), CuSparseMatrixCOO(_sptranspose(parent(B)))) ) - LinearAlgebra.kron(A::Transpose{T,<:$SparseMatrixType}, B::Transpose{T,<:$SparseMatrixType}) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(_sptranspose(parent(A))), CuSparseMatrixCOO(_sptranspose(parent(B)))) ) - LinearAlgebra.kron(A::Transpose{T,<:$SparseMatrixType}, B::Diagonal) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(_sptranspose(parent(A))), B) ) - LinearAlgebra.kron(A::Diagonal, B::Transpose{T,<:$SparseMatrixType}) where {T} = - $SparseMatrixType( kron(A, CuSparseMatrixCOO(_sptranspose(parent(B)))) ) - - LinearAlgebra.kron(A::Adjoint{T,<:$SparseMatrixType}, B::$SparseMatrixType{T}) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(_spadjoint(parent(A))), CuSparseMatrixCOO(B)) ) - LinearAlgebra.kron(A::$SparseMatrixType{T}, B::Adjoint{T,<:$SparseMatrixType}) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(A), CuSparseMatrixCOO(_spadjoint(parent(B)))) ) - LinearAlgebra.kron(A::Adjoint{T,<:$SparseMatrixType}, B::Adjoint{T,<:$SparseMatrixType}) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(_spadjoint(parent(A))), CuSparseMatrixCOO(_spadjoint(parent(B)))) ) - LinearAlgebra.kron(A::Adjoint{T,<:$SparseMatrixType}, B::Diagonal) where {T} = - $SparseMatrixType( kron(CuSparseMatrixCOO(_spadjoint(parent(A))), B) ) - LinearAlgebra.kron(A::Diagonal, B::Adjoint{T,<:$SparseMatrixType}) where {T} = - $SparseMatrixType( kron(A, CuSparseMatrixCOO(_spadjoint(parent(B)))) ) - - - function Base.reshape(A::$SparseMatrixType, dims::Dims) - B = CuSparseMatrixCOO(A) - $SparseMatrixType(reshape(B, dims)) - end - - function SparseArrays.droptol!(A::$SparseMatrixType, tol::Real) - B = CuSparseMatrixCOO(A) - droptol!(B, tol) - copyto!(A, $SparseMatrixType(B)) - end - - function LinearAlgebra.exp(A::$SparseMatrixType; threshold = 1e-7, nonzero_tol = 1e-14) - rows = LinearAlgebra.checksquare(A) # Throws exception if not square - typeA = eltype(A) - - mat_norm = norm(A, Inf) - scaling_factor = nextpow(2, mat_norm) # Native routine, faster - A = A ./ scaling_factor - delta = 1 - - P = $SparseMatrixType(spdiagm(0 => ones(eltype(A), rows))) - next_term = P - n = 1 - - while delta > threshold - next_term = typeA(1 / n) * A * next_term - droptol!(next_term, nonzero_tol) - delta = norm(next_term, Inf) - copyto!(P, P + next_term) - n = n + 1 - end - for n = 1:log2(scaling_factor) - P = P * P; - if nnz(P) / length(P) < 0.25 - droptol!(P, nonzero_tol) - end - end - P - end - end -end diff --git a/test/libraries/cusparse/conversions.jl b/test/libraries/cusparse/conversions.jl index 870d591fa0..4e6c9700c6 100644 --- a/test/libraries/cusparse/conversions.jl +++ b/test/libraries/cusparse/conversions.jl @@ -251,32 +251,32 @@ if !(v"12.0" <= CUSPARSE.version() < v"12.1") vals_Z = [5, 6, 7] |> cu csr_Z = CuSparseMatrixCSR{Float64}(crows_Z, cols_Z, vals_Z, (2,3)) - csr_to_csc_Z = CuSparseMatrixCSC{Float64}(csr_Z, index='Z') + csr_to_csc_Z = CuSparseMatrixCSC{Float64,Cint}(csr_Z, index='Z') @test csr_to_csc_Z.colPtr ≈ csc_Z.colPtr @test csr_to_csc_Z.rowVal ≈ csc_Z.rowVal @test csr_to_csc_Z.nzVal ≈ csc_Z.nzVal - csc_to_csr_Z = CuSparseMatrixCSR{Float64}(csc_Z, index='Z') + csc_to_csr_Z = CuSparseMatrixCSR{Float64,Cint}(csc_Z, index='Z') @test csc_to_csr_Z.rowPtr ≈ csr_Z.rowPtr @test csc_to_csr_Z.colVal ≈ csr_Z.colVal @test csc_to_csr_Z.nzVal ≈ csr_Z.nzVal - csr_to_coo_Z = CuSparseMatrixCOO{Float64}(csr_Z, index='Z') + csr_to_coo_Z = CuSparseMatrixCOO{Float64,Cint}(csr_Z, index='Z') @test csr_to_coo_Z.rowInd ≈ coo_Z.rowInd @test csr_to_coo_Z.colInd ≈ coo_Z.colInd @test csr_to_coo_Z.nzVal ≈ coo_Z.nzVal - coo_to_csr_Z = CuSparseMatrixCSR{Float64}(coo_Z, index='Z') + coo_to_csr_Z = CuSparseMatrixCSR{Float64,Cint}(coo_Z, index='Z') @test coo_to_csr_Z.rowPtr ≈ csr_Z.rowPtr @test coo_to_csr_Z.colVal ≈ csr_Z.colVal @test coo_to_csr_Z.nzVal ≈ csr_Z.nzVal - csc_to_coo_Z = CuSparseMatrixCOO{Float64}(csc_Z, index='Z') + csc_to_coo_Z = CuSparseMatrixCOO{Float64,Cint}(csc_Z, index='Z') @test csc_to_coo_Z.rowInd ≈ coo_Z.rowInd @test csc_to_coo_Z.colInd ≈ coo_Z.colInd @test csc_to_coo_Z.nzVal ≈ coo_Z.nzVal - coo_to_csc_Z = CuSparseMatrixCSC{Float64}(coo_Z, index='Z') + coo_to_csc_Z = CuSparseMatrixCSC{Float64,Cint}(coo_Z, index='Z') @test coo_to_csc_Z.colPtr ≈ csc_Z.colPtr @test coo_to_csc_Z.rowVal ≈ csc_Z.rowVal @test coo_to_csc_Z.nzVal ≈ csc_Z.nzVal diff --git a/test/libraries/cusparse/device.jl b/test/libraries/cusparse/device.jl index 650a065e48..164d5835b8 100644 --- a/test/libraries/cusparse/device.jl +++ b/test/libraries/cusparse/device.jl @@ -1,31 +1,32 @@ using CUDA.CUSPARSE using SparseArrays -using CUDA.CUSPARSE: CuSparseDeviceVector, CuSparseDeviceMatrixCSC, CuSparseDeviceMatrixCSR, - CuSparseDeviceMatrixBSR, CuSparseDeviceMatrixCOO +using SparseArrays: nonzeros, nnz, rowvals +using CUDA.GPUArrays: GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR, + GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO @testset "cudaconvert" begin - @test isbitstype(CuSparseDeviceVector{Float32, Cint, AS.Global}) - @test isbitstype(CuSparseDeviceMatrixCSC{Float32, Cint, AS.Global}) - @test isbitstype(CuSparseDeviceMatrixCSR{Float32, Cint, AS.Global}) - @test isbitstype(CuSparseDeviceMatrixBSR{Float32, Cint, AS.Global}) - @test isbitstype(CuSparseDeviceMatrixCOO{Float32, Cint, AS.Global}) + @test isbitstype(GPUSparseDeviceVector{Float32, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float32, AS.Global}, AS.Global}) + @test isbitstype(GPUSparseDeviceMatrixCSC{Float32, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float32, AS.Global}, AS.Global}) + @test isbitstype(GPUSparseDeviceMatrixCSR{Float32, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float32, AS.Global}, AS.Global}) + @test isbitstype(GPUSparseDeviceMatrixBSR{Float32, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float32, AS.Global}, AS.Global}) + @test isbitstype(GPUSparseDeviceMatrixCOO{Float32, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float32, AS.Global}, AS.Global}) V = sprand(10, 0.5) cuV = CuSparseVector(V) - @test cudaconvert(cuV) isa CuSparseDeviceVector{Float64, Cint, AS.Global} + @test cudaconvert(cuV) isa GPUSparseDeviceVector{Float64, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float64, AS.Global}, AS.Global} A = sprand(10, 10, 0.5) cuA = CuSparseMatrixCSC(A) - @test cudaconvert(cuA) isa CuSparseDeviceMatrixCSC{Float64, Cint, AS.Global} + @test cudaconvert(cuA) isa GPUSparseDeviceMatrixCSC{Float64, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float64, AS.Global}, AS.Global} cuA = CuSparseMatrixCSR(A) - @test cudaconvert(cuA) isa CuSparseDeviceMatrixCSR{Float64, Cint, AS.Global} + @test cudaconvert(cuA) isa GPUSparseDeviceMatrixCSR{Float64, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float64, AS.Global}, AS.Global} cuA = CuSparseMatrixCOO(A) - @test cudaconvert(cuA) isa CuSparseDeviceMatrixCOO{Float64, Cint, AS.Global} + @test cudaconvert(cuA) isa GPUSparseDeviceMatrixCOO{Float64, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float64, AS.Global}, AS.Global} cuA = CuSparseMatrixBSR(A, 2) - @test cudaconvert(cuA) isa CuSparseDeviceMatrixBSR{Float64, Cint, AS.Global} + @test cudaconvert(cuA) isa GPUSparseDeviceMatrixBSR{Float64, Cint, CuDeviceVector{Cint, AS.Global}, CuDeviceVector{Float64, AS.Global}, AS.Global} end @testset "device SparseArrays api" begin @@ -34,7 +35,7 @@ end function nnz_per_column_kernel(out, A) i = (blockIdx().x - 1) * blockDim().x + threadIdx().x col = @view A[:, i] - out[i] = SparseArrays.nnz(col) + out[i] = nnz(col) nothing end @@ -43,7 +44,7 @@ end out end - nnz_per_column(A::SparseMatrixCSC) = map(SparseArrays.nnz, eachcol(A)) + nnz_per_column(A::SparseMatrixCSC) = map(nnz, eachcol(A)) A = sprand(10, 10, 0.5) cuA = CuSparseMatrixCSC(A) @@ -59,7 +60,7 @@ end v = zero(Tv) i = threadIdx().x - while i <= SparseArrays.nnz(col) + while i <= nnz(col) v += nonzeros(col)[i] i += blockDim().x end @@ -89,7 +90,7 @@ end function last_nz_per_column_kernel(out, A) i = (blockIdx().x - 1) * blockDim().x + threadIdx().x col = @view A[:, i] - out[i] = last(SparseArrays.rowvals(col)) + out[i] = last(rowvals(col)) nothing end @@ -105,4 +106,4 @@ end @test last_nz_per_column(A) == Vector(last_nz_per_column(cuA)) end -end \ No newline at end of file +end From 293828c6b8cf8e5d35c16509552d6d7c753b2b3b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 22 Oct 2025 08:22:25 -0400 Subject: [PATCH 4/5] Keep abreast of GPUArrays branch --- lib/cusparse/array.jl | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl index 631e8ef8a0..f97b2d9542 100644 --- a/lib/cusparse/array.jl +++ b/lib/cusparse/array.jl @@ -95,25 +95,25 @@ function CUDA.unsafe_free!(xs::CuSparseMatrixCSR) return end -GPUArrays._sparse_array_type(sa::CuSparseMatrixCSC) = CuSparseMatrixCSC -GPUArrays._sparse_array_type(::Type{<:CuSparseMatrixCSC}) = CuSparseMatrixCSC -GPUArrays._sparse_array_type(sa::CuSparseMatrixCSR) = CuSparseMatrixCSR -GPUArrays._sparse_array_type(::Type{<:CuSparseMatrixCSR}) = CuSparseMatrixCSR -GPUArrays._sparse_array_type(sa::CuSparseVector) = CuSparseVector -GPUArrays._sparse_array_type(::Type{<:CuSparseVector}) = CuSparseVector -GPUArrays._dense_array_type(sa::CuSparseVector) = CuArray -GPUArrays._dense_array_type(::Type{<:CuSparseVector}) = CuArray -GPUArrays._dense_array_type(sa::CuSparseMatrixCSC) = CuArray -GPUArrays._dense_array_type(::Type{<:CuSparseMatrixCSC}) = CuArray -GPUArrays._dense_array_type(sa::CuSparseMatrixCSR) = CuArray -GPUArrays._dense_array_type(::Type{<:CuSparseMatrixCSR}) = CuArray - -GPUArrays._csc_type(sa::CuSparseMatrixCSR) = CuSparseMatrixCSC -GPUArrays._csr_type(sa::CuSparseMatrixCSC) = CuSparseMatrixCSR -GPUArrays._coo_type(sa::Union{CuSparseMatrixCSR, Transpose{<:Any,<:CuSparseMatrixCSR}, Adjoint{<:Any,<:CuSparseMatrixCSR}}) = CuSparseMatrixCOO -GPUArrays._coo_type(sa::Union{CuSparseMatrixCSC, Transpose{<:Any,<:CuSparseMatrixCSC}, Adjoint{<:Any,<:CuSparseMatrixCSC}}) = CuSparseMatrixCOO -GPUArrays._coo_type(::Type{T}) where {T<:Union{CuSparseMatrixCSR, Transpose{<:Any,<:CuSparseMatrixCSR}, Adjoint{<:Any,<:CuSparseMatrixCSR}}} = CuSparseMatrixCOO -GPUArrays._coo_type(::Type{T}) where {T<:Union{CuSparseMatrixCSC, Transpose{<:Any,<:CuSparseMatrixCSC}, Adjoint{<:Any,<:CuSparseMatrixCSC}}} = CuSparseMatrixCOO +GPUArrays.sparse_array_type(sa::CuSparseMatrixCSC) = CuSparseMatrixCSC +GPUArrays.sparse_array_type(::Type{<:CuSparseMatrixCSC}) = CuSparseMatrixCSC +GPUArrays.sparse_array_type(sa::CuSparseMatrixCSR) = CuSparseMatrixCSR +GPUArrays.sparse_array_type(::Type{<:CuSparseMatrixCSR}) = CuSparseMatrixCSR +GPUArrays.sparse_array_type(sa::CuSparseVector) = CuSparseVector +GPUArrays.sparse_array_type(::Type{<:CuSparseVector}) = CuSparseVector +GPUArrays.dense_array_type(sa::CuSparseVector) = CuArray +GPUArrays.dense_array_type(::Type{<:CuSparseVector}) = CuArray +GPUArrays.dense_array_type(sa::CuSparseMatrixCSC) = CuArray +GPUArrays.dense_array_type(::Type{<:CuSparseMatrixCSC}) = CuArray +GPUArrays.dense_array_type(sa::CuSparseMatrixCSR) = CuArray +GPUArrays.dense_array_type(::Type{<:CuSparseMatrixCSR}) = CuArray + +GPUArrays.csc_type(sa::CuSparseMatrixCSR) = CuSparseMatrixCSC +GPUArrays.csr_type(sa::CuSparseMatrixCSC) = CuSparseMatrixCSR +GPUArrays.coo_type(sa::Union{CuSparseMatrixCSR, Transpose{<:Any,<:CuSparseMatrixCSR}, Adjoint{<:Any,<:CuSparseMatrixCSR}}) = CuSparseMatrixCOO +GPUArrays.coo_type(sa::Union{CuSparseMatrixCSC, Transpose{<:Any,<:CuSparseMatrixCSC}, Adjoint{<:Any,<:CuSparseMatrixCSC}}) = CuSparseMatrixCOO +GPUArrays.coo_type(::Type{T}) where {T<:Union{CuSparseMatrixCSR, Transpose{<:Any,<:CuSparseMatrixCSR}, Adjoint{<:Any,<:CuSparseMatrixCSR}}} = CuSparseMatrixCOO +GPUArrays.coo_type(::Type{T}) where {T<:Union{CuSparseMatrixCSC, Transpose{<:Any,<:CuSparseMatrixCSC}, Adjoint{<:Any,<:CuSparseMatrixCSC}}} = CuSparseMatrixCOO """ Container to hold sparse matrices in block compressed sparse row (BSR) format on From 365ec16288f8a7d3f77e5ce1857fcee269fbaa44 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 24 Oct 2025 03:16:44 -0400 Subject: [PATCH 5/5] More updates (mostly clearing out) --- lib/cusparse/array.jl | 39 --------------------------------------- 1 file changed, 39 deletions(-) diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl index f97b2d9542..d3e9f0e506 100644 --- a/lib/cusparse/array.jl +++ b/lib/cusparse/array.jl @@ -567,19 +567,6 @@ SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCSR) = SparseMatrixCSC(CuSparseMat SparseArrays.SparseMatrixCSC(x::CuSparseMatrixBSR) = SparseMatrixCSC(CuSparseMatrixCSR(x)) # no direct conversion (gpu_BSR -> gpu_CSR -> gpu_CSC -> cpu_CSC) SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCOO) = SparseMatrixCSC(CuSparseMatrixCSC(x)) # no direct conversion (gpu_COO -> gpu_CSC -> cpu_CSC) -# collect to Array -Base.collect(x::CuSparseVector) = collect(SparseVector(x)) -Base.collect(x::CuSparseMatrixCSC) = collect(SparseMatrixCSC(x)) -Base.collect(x::CuSparseMatrixCSR) = collect(SparseMatrixCSC(x)) -Base.collect(x::CuSparseMatrixBSR) = collect(SparseMatrixCSC(x)) -Base.collect(x::CuSparseMatrixCOO) = collect(SparseMatrixCSC(x)) - -Base.Array(x::CuSparseVector) = collect(SparseVector(x)) -Base.Array(x::CuSparseMatrixCSC) = collect(SparseMatrixCSC(x)) -Base.Array(x::CuSparseMatrixCSR) = collect(SparseMatrixCSC(x)) -Base.Array(x::CuSparseMatrixBSR) = collect(SparseMatrixCSC(x)) -Base.Array(x::CuSparseMatrixCOO) = collect(SparseMatrixCSC(x)) - Adapt.adapt_storage(::Type{CuArray}, xs::SparseVector) = CuSparseVector(xs) Adapt.adapt_storage(::Type{CuArray}, xs::SparseMatrixCSC) = CuSparseMatrixCSC(xs) Adapt.adapt_storage(::Type{CuArray{T}}, xs::SparseVector) where {T} = CuSparseVector{T}(xs) @@ -596,32 +583,6 @@ Adapt.adapt_storage(::Type{Array}, xs::CuSparseMatrixCSC) = SparseMatrixCSC(xs) ## copying between sparse GPU arrays -function Base.copyto!(dst::CuSparseVector, src::CuSparseVector) - if length(dst) != length(src) - throw(ArgumentError("Inconsistent Sparse Vector size")) - end - resize!(nonzeroinds(dst), length(nonzeroinds(src))) - resize!(nonzeros(dst), length(nonzeros(src))) - copyto!(nonzeroinds(dst), nonzeroinds(src)) - copyto!(nonzeros(dst), nonzeros(src)) - dst.nnz = src.nnz - dst -end - -function Base.copyto!(dst::CuSparseMatrixCSC, src::CuSparseMatrixCSC) - if size(dst) != size(src) - throw(ArgumentError("Inconsistent Sparse Matrix size")) - end - resize!(dst.colPtr, length(src.colPtr)) - resize!(rowvals(dst), length(rowvals(src))) - resize!(nonzeros(dst), length(nonzeros(src))) - copyto!(dst.colPtr, src.colPtr) - copyto!(rowvals(dst), rowvals(src)) - copyto!(nonzeros(dst), nonzeros(src)) - dst.nnz = src.nnz - dst -end - function Base.copyto!(dst::CuSparseMatrixCSR, src::CuSparseMatrixCSR) if size(dst) != size(src) throw(ArgumentError("Inconsistent Sparse Matrix size"))