Skip to content
20 changes: 14 additions & 6 deletions sparse/_compressed/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
can_store,
check_zero_fill_value,
check_compressed_axes,
_zero_of_dtype,
equivalent,
)
from .._coo.core import COO
Expand Down Expand Up @@ -143,7 +144,7 @@ def __init__(
shape=None,
compressed_axes=None,
prune=False,
fill_value=0,
fill_value=None,
idx_dtype=None,
):
if isinstance(arg, ss.spmatrix):
Expand All @@ -169,6 +170,10 @@ def __init__(
arg.fill_value,
)

self.data, self.indices, self.indptr = arg

if fill_value is None:
fill_value = _zero_of_dtype(self.data.dtype)
if shape is None:
raise ValueError("missing `shape` argument")

Expand All @@ -177,8 +182,6 @@ def __init__(
if len(shape) == 1:
compressed_axes = None

self.data, self.indices, self.indptr = arg

if self.data.ndim != 1:
raise ValueError("data must be a scalar or 1-dimensional.")

Expand Down Expand Up @@ -845,7 +848,12 @@ def _prune(self):

class _Compressed2d(GCXS):
def __init__(
self, arg, shape=None, compressed_axes=None, prune=False, fill_value=0
self,
arg,
shape=None,
compressed_axes=None,
prune=False,
fill_value=None,
):
if not hasattr(arg, "shape") and shape is None:
raise ValueError("missing `shape` argument")
Expand Down Expand Up @@ -888,7 +896,7 @@ class CSR(_Compressed2d):
Sparse supports 2-D CSR.
"""

def __init__(self, arg, shape=None, prune=False, fill_value=0):
def __init__(self, arg, shape=None, prune=False, fill_value=None):
super().__init__(arg, shape=shape, compressed_axes=(0,), fill_value=fill_value)

@classmethod
Expand All @@ -913,7 +921,7 @@ class CSC(_Compressed2d):
Sparse supports 2-D CSC.
"""

def __init__(self, arg, shape=None, prune=False, fill_value=0):
def __init__(self, arg, shape=None, prune=False, fill_value=None):
super().__init__(arg, shape=shape, compressed_axes=(1,), fill_value=fill_value)

@classmethod
Expand Down
155 changes: 155 additions & 0 deletions sparse/_compressed/elemwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from functools import lru_cache
from typing import Callable

import numpy as np
import scipy.sparse
from numba import njit

from .compressed import _Compressed2d


def op_unary(func, a):
res = a.copy()
res.data = func(a.data)
return res

Check warning on line 14 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L12-L14

Added lines #L12 - L14 were not covered by tests


@lru_cache(maxsize=None)
def _numba_d(func):
return njit(lambda *x: func(*x))


def binary_op(func, a, b):
func = _numba_d(func)
if isinstance(a, _Compressed2d) and isinstance(b, _Compressed2d):
return op_union_indices(func, a, b)
else:
raise NotImplementedError()

# From scipy._util
def _prune_array(array):
"""Return an array equivalent to the input array. If the input
array is a view of a much larger array, copy its contents to a
newly allocated array. Otherwise, return the input unchanged.
"""
if array.base is not None and array.size < array.base.size // 2:
return array.copy()
return array



def op_union_indices(
op: Callable, a: scipy.sparse.csr_matrix, b: scipy.sparse.csr_matrix, *, default_value=0
):
assert a.shape == b.shape

if type(a) != type(b):
b = type(a)(b)

Check warning on line 47 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L47

Added line #L47 was not covered by tests
# a.sort_indices()
# b.sort_indices()

# TODO: numpy is weird with bools here
out_dtype = np.array(op(a.data[0], b.data[0])).dtype
default_value = out_dtype.type(default_value)
out_indptr = np.zeros_like(a.indptr)
out_indices = np.zeros(len(a.indices) + len(b.indices), dtype=np.promote_types(a.indices.dtype, b.indices.dtype))
out_data = np.zeros(len(out_indices), dtype=out_dtype)

nnz = op_union_indices_csr_csr(
op,
a.indptr,
a.indices,
a.data,
b.indptr,
b.indices,
b.data,
out_indptr,
out_indices,
out_data,
out_dtype=out_dtype,
default_value=default_value,
)
out_data = _prune_array(out_data[:nnz])
out_indices = _prune_array(out_indices[:nnz])
return type(a)((out_data, out_indices, out_indptr), shape=a.shape)


@njit
def op_union_indices_csr_csr(
op: Callable,
a_indptr: np.ndarray,
a_indices: np.ndarray,
a_data: np.ndarray,
b_indptr: np.ndarray,
b_indices: np.ndarray,
b_data: np.ndarray,
out_indptr: np.ndarray,
out_indices: np.ndarray,
out_data: np.ndarray,
out_dtype,
default_value,
):
# out_indptr = np.zeros_like(a_indptr)
# out_indices = np.zeros(len(a_indices) + len(b_indices), dtype=a_indices.dtype)
# out_data = np.zeros(len(out_indices), dtype=out_dtype)

out_idx = 0

Check warning on line 96 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L96

Added line #L96 was not covered by tests

for i in range(len(a_indptr) - 1):

Check warning on line 98 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L98

Added line #L98 was not covered by tests

a_idx = a_indptr[i]
a_end = a_indptr[i + 1]
b_idx = b_indptr[i]
b_end = b_indptr[i + 1]

Check warning on line 103 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L100-L103

Added lines #L100 - L103 were not covered by tests

while (a_idx < a_end) and (b_idx < b_end):
a_j = a_indices[a_idx]
b_j = b_indices[b_idx]
if a_j < b_j:
val = op(a_data[a_idx], default_value)
if val != default_value:
out_indices[out_idx] = a_j
out_data[out_idx] = val
out_idx += 1
a_idx += 1
elif b_j < a_j:
val = op(default_value, b_data[b_idx])
if val != default_value:
out_indices[out_idx] = b_j
out_data[out_idx] = val
out_idx += 1
b_idx += 1

Check warning on line 121 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L105-L121

Added lines #L105 - L121 were not covered by tests
else:
val = op(a_data[a_idx], b_data[b_idx])
if val != default_value:
out_indices[out_idx] = a_j
out_data[out_idx] = val
out_idx += 1
a_idx += 1
b_idx += 1

Check warning on line 129 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L123-L129

Added lines #L123 - L129 were not covered by tests

# Catch up the other set
while a_idx < a_end:
val = op(a_data[a_idx], default_value)
if val != default_value:
out_indices[out_idx] = a_indices[a_idx]
out_data[out_idx] = val
out_idx += 1
a_idx += 1

Check warning on line 138 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L132-L138

Added lines #L132 - L138 were not covered by tests

while b_idx < b_end:
val = op(default_value, b_data[b_idx])
if val != default_value:
out_indices[out_idx] = b_indices[b_idx]
out_data[out_idx] = val
out_idx += 1
b_idx += 1

Check warning on line 146 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L140-L146

Added lines #L140 - L146 were not covered by tests

out_indptr[i + 1] = out_idx

Check warning on line 148 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L148

Added line #L148 was not covered by tests

# This may need to change to be "resize" to allow memory reallocation
# resize is currently not implemented in numba
out_indices = out_indices[: out_idx]
out_data = out_data[: out_idx]

Check warning on line 153 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L152-L153

Added lines #L152 - L153 were not covered by tests

return out_idx

Check warning on line 155 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L155

Added line #L155 was not covered by tests
110 changes: 92 additions & 18 deletions sparse/_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,48 @@
)


# TODO: Figure out the right way to type this
# TODO: Figure out how to do 1d COO + CSR or CSC
def _resolve_result_type(args: "list[ArrayLike]") -> "Type":
from ._compressed import GCXS, CSR, CSC
from ._coo import COO
from ._dok import DOK
from ._sparse_array import SparseArray
from ._compressed.compressed import _Compressed2d

args = [arg for arg in args if isinstance(arg, SparseArray)]

if all(isinstance(arg, DOK) for arg in args):
out_type = DOK
elif all(isinstance(arg, CSR) for arg in args):
out_type = CSR
elif all(isinstance(arg, CSC) for arg in args):
out_type = CSC
elif all(isinstance(arg, _Compressed2d) for arg in args):
out_type = CSR
elif all(isinstance(arg, GCXS) for arg in args):
out_type = GCXS
else:
out_type = COO
return out_type


def _from_scipy_sparse(a):
from ._compressed import CSR, CSC
from ._coo import COO
from ._dok import DOK

assert isinstance(a, scipy.sparse.spmatrix)
if isinstance(a, scipy.sparse.csr_matrix):
return CSR(a)
elif isinstance(a, scipy.sparse.csc_matrix):
return CSC(a)
elif isinstance(a, scipy.sparse.dok_matrix):
return DOK(a.shape, data=dict(a))
else:
return COO(a)


class _Elemwise:
def __init__(self, func, *args, **kwargs):
"""
Expand All @@ -423,24 +465,26 @@
"""
from ._coo import COO
from ._sparse_array import SparseArray
from ._compressed import GCXS
from ._compressed import GCXS, CSR, CSC
from ._compressed.compressed import _Compressed2d
from ._dok import DOK

processed_args = []
out_type = GCXS

sparse_args = [arg for arg in args if isinstance(arg, SparseArray)]
args = [
arg
if not isinstance(arg, scipy.sparse.spmatrix)
else _from_scipy_sparse(arg)
for arg in args
]

if all(isinstance(arg, DOK) for arg in sparse_args):
out_type = DOK
elif all(isinstance(arg, GCXS) for arg in sparse_args):
out_type = GCXS
else:
out_type = COO
processed_args = []

self.out_type = _resolve_result_type(args)
# Should this happen before dispatch?
# Hmm, this may need major major changes.
# Case to consider: CSR or CSC + 1d COO
for arg in args:
if isinstance(arg, scipy.sparse.spmatrix):
processed_args.append(COO.from_scipy_sparse(arg))
if self.out_type != COO and isinstance(arg, _Compressed2d):
processed_args.append(arg)
elif isscalar(arg) or isinstance(arg, np.ndarray):
# Faster and more reliable to pass ()-shaped ndarrays as scalars.
processed_args.append(np.asarray(arg))
Expand All @@ -454,7 +498,6 @@
self.args = None
return

self.out_type = out_type
self.args = tuple(processed_args)
self.func = func
self.dtype = kwargs.pop("dtype", None)
Expand All @@ -467,14 +510,19 @@

def get_result(self):
from ._coo import COO
from ._sparse_array import SparseArray
from ._compressed.compressed import _Compressed2d

if self.args is None:
return NotImplemented

if self._dense_result:
args = [a.todense() if isinstance(a, COO) else a for a in self.args]
args = [a.todense() if isinstance(a, SparseArray) else a for a in self.args]
return self.func(*args, **self.kwargs)

if issubclass(self.out_type, _Compressed2d):
return self._get_result_compressed_2d()

if any(s == 0 for s in self.shape):
data = np.empty((0,), dtype=self.fill_value.dtype)
coords = np.empty((0, len(self.shape)), dtype=np.intp)
Expand Down Expand Up @@ -521,6 +569,29 @@
fill_value=self.fill_value,
).asformat(self.out_type)

def _get_result_compressed_2d(self):
from ._compressed import elemwise as elemwise2d
from ._compressed.compressed import _Compressed2d

if len(self.args) == 1:
result = elemwise2d.op_unary(self.func, self.args[0])

Check warning on line 577 in sparse/_umath.py

View check run for this annotation

Codecov / codecov/patch

sparse/_umath.py#L577

Added line #L577 was not covered by tests

processed_args = []
for arg in self.args:
if isinstance(arg, self.out_type):
processed_args.append(arg)
elif isinstance(arg, _Compressed2d):
processed_args.append(self.out_type(arg))
elif isinstance(arg, np.ndarray):
processed_args.append(np.broadcast_to(arg, self.shape))
else:
raise NotImplementedError()

if len(processed_args) == 2:
result = elemwise2d.binary_op(self.func, *processed_args)

return result

def _get_fill_value(self):
"""
A function that finds and returns the fill-value.
Expand All @@ -530,10 +601,11 @@
ValueError
If the fill-value is inconsistent.
"""
from ._coo import COO
from ._sparse_array import SparseArray

zero_args = tuple(
arg.fill_value[...] if isinstance(arg, COO) else arg for arg in self.args
arg.fill_value[...] if isinstance(arg, SparseArray) else arg
for arg in self.args
)

# Some elemwise functions require a dtype argument, some abhorr it.
Expand All @@ -550,7 +622,9 @@
fill_value = fill_value_array[(0,) * fill_value_array.ndim]
except IndexError:
zero_args = tuple(
arg.fill_value if isinstance(arg, COO) else _zero_of_dtype(arg.dtype)
arg.fill_value
if isinstance(arg, SparseArray)
else _zero_of_dtype(arg.dtype)
for arg in self.args
)
fill_value = self.func(*zero_args, **self.kwargs)[()]
Expand Down
Loading